@@ -149,6 +149,8 @@ protected Compressor initialValue() {

private final boolean dataPushFailureTrackingEnabled;

private final boolean shuffleWriteLimitEnabled;

public static class ReduceFileGroups {
public Map<Integer, Set<PartitionLocation>> partitionGroups;
public Map<String, LocationPushFailedBatches> pushFailedBatches;
@@ -211,6 +213,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
}
authEnabled = conf.authEnabledOnClient();
dataPushFailureTrackingEnabled = conf.clientAdaptiveOptimizeSkewedPartitionReadEnabled();
shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled();

// init rpc env
rpcEnv =
@@ -1067,6 +1070,10 @@ public int pushOrMergeData(
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length);
System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length);

if (shuffleWriteLimitEnabled) {
pushState.addWrittenBytes(body.length);
}

if (doPush) {
// check limit
limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort());
@@ -1789,6 +1796,8 @@ private void mapEndInternal(
long[] bytesPerPartition =
pushState.getBytesWrittenPerPartition(shuffleIntegrityCheckEnabled, numPartitions);

long bytesWritten = pushState.getBytesWritten();

MapperEndResponse response =
lifecycleManagerRef.askSync(
new MapperEnd(
@@ -1801,7 +1810,8 @@
numPartitions,
crc32PerPartition,
bytesPerPartition,
SerdeVersion.V1),
SerdeVersion.V1,
bytesWritten),
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(MapperEndResponse.class));
@@ -23,7 +23,7 @@ import java.security.SecureRandom
import java.util
import java.util.{function, List => JList}
import java.util.concurrent._
import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, LongAdder}
import java.util.function.{BiConsumer, BiFunction, Consumer}

import scala.collection.JavaConverters._
@@ -132,6 +132,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private val mockDestroyFailure = conf.testMockDestroySlotsFailure
private val authEnabled = conf.authEnabledOnClient
private var applicationMeta: ApplicationMeta = _

private val shuffleWriteLimitEnabled = conf.shuffleWriteLimitEnabled
private val shuffleWriteLimitThreshold = conf.shuffleWriteLimitThreshold
private val shuffleTotalWrittenBytes = JavaUtils.newConcurrentHashMap[Int, AtomicLong]()
Review comment (Contributor): Should clean shuffleId related data when shuffle expires.

Reply (Contributor Author): Yes.

@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[String, ShufflePartitionLocationInfo] =
shuffleAllocatedWorkers.get(shuffleId)
@@ -439,7 +444,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition,
serdeVersion) =>
serdeVersion,
bytesWritten) =>
logTrace(s"Received MapperEnd TaskEnd request, " +
s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
val partitionType = getPartitionType(shuffleId)
@@ -455,7 +461,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition,
serdeVersion)
serdeVersion,
bytesWritten)
case PartitionType.MAP =>
handleMapPartitionEnd(
context,
@@ -933,7 +940,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long],
serdeVersion: SerdeVersion): Unit = {
serdeVersion: SerdeVersion,
bytesWritten: Long): Unit = {

val (mapperAttemptFinishedSuccess, allMapperFinished) =
commitManager.finishMapperAttempt(
@@ -945,6 +953,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
numPartitions = numPartitions,
crc32PerPartition = crc32PerPartition,
bytesWrittenPerPartition = bytesWrittenPerPartition)

if (mapperAttemptFinishedSuccess && shuffleWriteLimitEnabled) {
Review comment (Contributor): Should we consider adding a negative test case for this feature?

handleShuffleWriteLimitCheck(shuffleId, bytesWritten)
logDebug(s"Shuffle $shuffleId, mapId: $mapId, attemptId: $attemptId, " +
s"map written bytes: $bytesWritten, shuffle total written bytes: ${shuffleTotalWrittenBytes.get(
shuffleId).get()}, write limit threshold: $shuffleWriteLimitThreshold")
}

if (mapperAttemptFinishedSuccess && allMapperFinished) {
// last mapper finished. call mapper end
logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" +
@@ -1258,6 +1274,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
logInfo(s"[handleUnregisterShuffle] Wait for handleStageEnd complete costs ${cost}ms")
}
}

if (shuffleWriteLimitEnabled) {
shuffleTotalWrittenBytes.remove(shuffleId)
}
}

// add shuffleKey to delay shuffle removal set
@@ -2081,4 +2101,25 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}

def getShuffleIdMapping = shuffleIdMapping

private def handleShuffleWriteLimitCheck(shuffleId: Int, writtenBytes: Long): Unit = {
if (!shuffleWriteLimitEnabled || shuffleWriteLimitThreshold <= 0) return

if (writtenBytes > 0) {
val totalBytesAccumulator =
shuffleTotalWrittenBytes.computeIfAbsent(shuffleId, (id: Int) => new AtomicLong(0))
val currentTotalBytes = totalBytesAccumulator.addAndGet(writtenBytes)

if (currentTotalBytes > shuffleWriteLimitThreshold) {
val reason =
s"Shuffle $shuffleId exceeded write limit threshold: current total ${currentTotalBytes} bytes, max allowed ${shuffleWriteLimitThreshold} bytes"
logError(reason)

cancelShuffleCallback match {
case Some(c) => c.accept(shuffleId, reason)
case _ => None
}
}
}
}
}
@@ -40,10 +40,13 @@ public class PushState {

private final Map<String, LocationPushFailedBatches> failedBatchMap;

private long bytesWritten;

public PushState(CelebornConf conf) {
pushBufferMaxSize = conf.clientPushBufferMaxSize();
inFlightRequestTracker = new InFlightRequestTracker(conf, this);
failedBatchMap = JavaUtils.newConcurrentHashMap();
bytesWritten = 0;
}

public void cleanup() {
@@ -136,4 +139,12 @@ public void addDataWithOffsetAndLength(int partitionId, byte[] data, int offset,
commitMetadataMap.computeIfAbsent(partitionId, id -> new CommitMetadata());
commitMetadata.addDataWithOffsetAndLength(data, offset, length);
}

public void addWrittenBytes(int length) {
bytesWritten += length;
}

public long getBytesWritten() {
return bytesWritten;
}
}
1 change: 1 addition & 0 deletions common/src/main/proto/TransportMessages.proto
@@ -377,6 +377,7 @@ message PbMapperEnd {
int32 numPartitions = 7;
repeated int32 crc32PerPartition = 8;
repeated int64 bytesWrittenPerPartition = 9;
int64 bytesWritten = 10;
}

message PbLocationPushFailedBatches {
@@ -1670,6 +1670,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def secretRedactionPattern = get(SECRET_REDACTION_PATTERN)

def containerInfoProviderClass = get(CONTAINER_INFO_PROVIDER)

def shuffleWriteLimitEnabled: Boolean = get(SHUFFLE_WRITE_LIMIT_ENABLED)

def shuffleWriteLimitThreshold: Long = get(SHUFFLE_WRITE_LIMIT_THRESHOLD)
}

object CelebornConf extends Logging {
@@ -6826,4 +6830,19 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(false)

val SHUFFLE_WRITE_LIMIT_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.spark.shuffle.write.limit.enabled")
.categories("client")
.doc("Enable shuffle write limit check to prevent cluster resource exhaustion.")
.version("0.7.0")
.booleanConf
.createWithDefault(false)

val SHUFFLE_WRITE_LIMIT_THRESHOLD: ConfigEntry[Long] =
buildConf("celeborn.client.spark.shuffle.write.limit.threshold")
.categories("client")
.doc("Shuffle write limit threshold, exceed to cancel oversized shuffle.")
.version("0.7.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("5TB")
}
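
As a usage note for the two new config entries above: in a Spark job these keys would typically be forwarded through Spark's configuration with the usual `spark.` prefix for Celeborn client settings. This is only a sketch; the key names come from the definitions above, while the `spark.` prefix convention and the chosen threshold value are assumptions about the deployment, not part of this patch.

import org.apache.spark.SparkConf

// Sketch only: enable the shuffle write limit and raise the threshold to 10 TB.
// The keys are the ConfigEntry names defined above; the "spark." prefix is the
// usual pass-through convention for Celeborn client configs (assumed here).
val sparkConf = new SparkConf()
  .set("spark.celeborn.client.spark.shuffle.write.limit.enabled", "true")
  .set("spark.celeborn.client.spark.shuffle.write.limit.threshold", "10TB")
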
@@ -218,7 +218,8 @@ object ControlMessages extends Logging {
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long],
serdeVersion: SerdeVersion)
serdeVersion: SerdeVersion,
bytesWritten: Long)
Review comment (Contributor): IIUC this patch is changing the protocol. How is this handled? Will the new client be compatible with the new server version, and vice versa?

Reply (Contributor): It's not a big deal since MapperEnd is only utilized on the engine side.

extends MasterMessage
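
For context on the compatibility question above: `bytesWritten` is added as a new proto3 scalar field (field 10) of `PbMapperEnd`, so a payload produced by an older client simply omits it and a newer decoder falls back to the proto3 default of 0. A minimal sketch of that behaviour, assuming the generated `PbMapperEnd` class (not part of this patch):

// Package assumed from the generated protobuf classes used in ControlMessages.scala.
import org.apache.celeborn.common.protocol.PbMapperEnd

// Build a PbMapperEnd the way an older client would, i.e. without setting bytesWritten.
val oldPayload = PbMapperEnd.newBuilder()
  .setNumPartitions(1) // any pre-existing field
  .build()
  .toByteArray

// A newer reader decodes the absent field as the proto3 default value.
val decoded = PbMapperEnd.parseFrom(oldPayload)
assert(decoded.getBytesWritten == 0L)
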

case class ReadReducerPartitionEnd(
@@ -737,7 +738,8 @@ object ControlMessages extends Logging {
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition,
serdeVersion) =>
serdeVersion,
bytesWritten) =>
val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
(k, resultValue)
@@ -753,6 +755,7 @@
.addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava)
.addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
java.lang.Long.valueOf).toSeq.asJava)
.setBytesWritten(bytesWritten)
.build().toByteArray
new TransportMessage(MessageType.MAPPER_END, payload, serdeVersion)

@@ -1248,7 +1251,8 @@
pbMapperEnd.getNumPartitions,
crc32Array,
bytesWrittenPerPartitionArray,
message.getSerdeVersion)
message.getSerdeVersion,
pbMapperEnd.getBytesWritten)

case READ_REDUCER_PARTITION_END_VALUE =>
val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
@@ -160,7 +160,8 @@ class UtilsSuite extends CelebornFunSuite {
1,
Array.emptyIntArray,
Array.emptyLongArray,
SerdeVersion.V1)
SerdeVersion.V1,
1)
val mapperEndTrans =
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
@@ -172,6 +173,7 @@
assert(mapperEnd.numPartitions == mapperEndTrans.numPartitions)
mapperEnd.crc32PerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.crc32PerPartition
mapperEnd.bytesWrittenPerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.bytesWrittenPerPartition
assert(mapperEnd.bytesWritten == mapperEndTrans.bytesWritten)
}

test("validate HDFS compatible fs path") {
2 changes: 2 additions & 0 deletions docs/configuration/client.md
@@ -137,6 +137,8 @@ license: |
| celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled |
| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | |
| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | |
| celeborn.client.spark.shuffle.write.limit.enabled | false | false | Enable shuffle write limit check to prevent cluster resource exhaustion. | 0.7.0 | |
| celeborn.client.spark.shuffle.write.limit.threshold | 5TB | false | Threshold of total bytes written to a shuffle; once exceeded, the oversized shuffle is cancelled. | 0.7.0 | |
| celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer |
| celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure |
| celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider |
5 changes: 5 additions & 0 deletions tests/spark-it/pom.xml
@@ -187,6 +187,11 @@
<artifactId>minio</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
@@ -18,15 +18,21 @@
package org.apache.celeborn.tests.client

import java.util
import java.util.Collections
import java.util.function.BiConsumer

import org.mockito.Mockito
import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.{interval, timeout}
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime

import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.network.protocol.SerdeVersion
import org.apache.celeborn.common.protocol.message.ControlMessages.MapperEnd
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.service.deploy.MiniClusterFeature

class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeature {
@@ -126,6 +132,68 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu
}
}

test("CELEBORN-2264: Support cancel shuffle when write bytes exceeds threshold") {
val conf = celebornConf.clone
conf.set(CelebornConf.SHUFFLE_WRITE_LIMIT_ENABLED.key, "true")
.set(CelebornConf.SHUFFLE_WRITE_LIMIT_THRESHOLD.key, "2000")
val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf)
val ctx = Mockito.mock(classOf[RpcCallContext])

// Custom BiConsumer callback to track if cancelShuffle is invoked
var isCancelShuffleInvoked = false
val cancelShuffleCallback = new BiConsumer[Integer, String] {
override def accept(shuffleId: Integer, reason: String): Unit = {
isCancelShuffleInvoked = true
}
}
lifecycleManager.registerCancelShuffleCallback(cancelShuffleCallback)

// Scenario 1: Same mapper with multiple attempts (total bytes exceed threshold but no cancel)
val shuffleId = 0
val mapId1 = 0
lifecycleManager.receiveAndReply(ctx)(MapperEnd(
shuffleId = shuffleId,
mapId = mapId1,
attemptId = 0,
2,
1,
Collections.emptyMap(),
1,
Array.emptyIntArray,
Array.emptyLongArray,
SerdeVersion.V1,
bytesWritten = 1500))
lifecycleManager.receiveAndReply(ctx)(MapperEnd(
shuffleId = shuffleId,
mapId = mapId1,
attemptId = 1,
2,
1,
Collections.emptyMap(),
1,
Array.emptyIntArray,
Array.emptyLongArray,
SerdeVersion.V1,
bytesWritten = 1500))
assert(!isCancelShuffleInvoked)

// Scenario 2: Total bytes of mapId1 + mapId2 exceed threshold (trigger cancel)
val mapId2 = 1
lifecycleManager.receiveAndReply(ctx)(MapperEnd(
shuffleId = shuffleId,
mapId = mapId2,
attemptId = 0,
2,
1,
Collections.emptyMap(),
1,
Array.emptyIntArray,
Array.emptyLongArray,
SerdeVersion.V1,
bytesWritten = 1000))
assert(isCancelShuffleInvoked)
}

override def afterAll(): Unit = {
logInfo("all test complete , stop celeborn mini cluster")
shutdownMiniCluster()