
Commit 6cd1ccc

Ngone51 authored and dongjoon-hyun committed
[SPARK-48394][CORE] Cleanup mapIdToMapIndex on mapoutput unregister
### What changes were proposed in this pull request?

This PR cleans up `mapIdToMapIndex` when the corresponding `MapStatus` is unregistered, in three places:

* `removeMapOutput`
* `removeOutputsByFilter`
* `addMapOutput` (when an old `MapStatus` is overwritten)

### Why are the changes needed?

At any given time there is only one valid `MapStatus` for a given `mapIndex` in Spark. `mapIdToMapIndex` should follow the same rule, to avoid stale entries.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46706 from Ngone51/SPARK-43043-followup.

Lead-authored-by: Yi Wu <yi.wu@databricks.com>
Co-authored-by: wuyi <yi.wu@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent 1a536f0 commit 6cd1ccc
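The invariant being enforced: at any moment a `mapIndex` has at most one live `MapStatus`, so `mapIdToMapIndex` must not keep entries for `mapId`s whose status has been removed or overwritten. Below is a minimal standalone sketch of the overwrite case, assuming a toy `Status` class in place of Spark's `MapStatus` (this is not the real `ShuffleStatus` implementation):

```scala
import scala.collection.mutable.HashMap

// Toy stand-in for MapStatus: only the unique mapId matters here.
final case class Status(mapId: Long)

object StaleEntryDemo {
  private val mapStatuses = new Array[Status](1) // a single map partition
  private val mapIdToMapIndex = new HashMap[Long, Int]()

  def addMapOutput(mapIndex: Int, status: Status): Unit = {
    val currentMapStatus = mapStatuses(mapIndex)
    if (currentMapStatus != null) {
      // The fix: drop the superseded mapId so it can no longer resolve to
      // an index whose live status belongs to a different task attempt.
      mapIdToMapIndex.remove(currentMapStatus.mapId)
    }
    mapStatuses(mapIndex) = status
    mapIdToMapIndex(status.mapId) = mapIndex
  }

  def main(args: Array[String]): Unit = {
    addMapOutput(0, Status(mapId = 2L)) // first attempt for partition 0
    addMapOutput(0, Status(mapId = 7L)) // a retry overwrites the old status
    // Without the remove() above, both 2L and 7L would map to index 0.
    assert(mapIdToMapIndex == HashMap(7L -> 0))
  }
}
```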

File tree: 2 files changed (+72, -9 lines)


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 17 additions & 9 deletions
```diff
@@ -44,7 +44,6 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
 import org.apache.spark.util._
 import org.apache.spark.util.ArrayImplicits._
-import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
 /**
@@ -153,17 +152,22 @@ private class ShuffleStatus(
   /**
    * Mapping from a mapId to the mapIndex, this is required to reduce the searching overhead within
    * the function updateMapOutput(mapId, bmAddress).
+   *
+   * Exposed for testing.
    */
-  private[this] val mapIdToMapIndex = new OpenHashMap[Long, Int]()
+  private[spark] val mapIdToMapIndex = new HashMap[Long, Int]()
 
   /**
    * Register a map output. If there is already a registered location for the map output then it
    * will be replaced by the new location.
    */
   def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
-    if (mapStatuses(mapIndex) == null) {
+    val currentMapStatus = mapStatuses(mapIndex)
+    if (currentMapStatus == null) {
       _numAvailableMapOutputs += 1
       invalidateSerializedMapOutputStatusCache()
+    } else {
+      mapIdToMapIndex.remove(currentMapStatus.mapId)
     }
     mapStatuses(mapIndex) = status
     mapIdToMapIndex(status.mapId) = mapIndex
@@ -193,8 +197,8 @@ private class ShuffleStatus(
          mapStatus.updateLocation(bmAddress)
          invalidateSerializedMapOutputStatusCache()
        case None =>
-          if (mapIndex.map(mapStatusesDeleted).exists(_.mapId == mapId)) {
-            val index = mapIndex.get
+          val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId)
+          if (index >= 0 && mapStatuses(index) == null) {
            val mapStatus = mapStatusesDeleted(index)
            mapStatus.updateLocation(bmAddress)
            mapStatuses(index) = mapStatus
@@ -222,9 +226,11 @@ private class ShuffleStatus(
    */
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
     logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
-    if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
+    val currentMapStatus = mapStatuses(mapIndex)
+    if (currentMapStatus != null && currentMapStatus.location == bmAddress) {
       _numAvailableMapOutputs -= 1
-      mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
+      mapIdToMapIndex.remove(currentMapStatus.mapId)
+      mapStatusesDeleted(mapIndex) = currentMapStatus
       mapStatuses(mapIndex) = null
       invalidateSerializedMapOutputStatusCache()
     }
@@ -290,9 +296,11 @@ private class ShuffleStatus(
    */
   def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
-      if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
+      val currentMapStatus = mapStatuses(mapIndex)
+      if (currentMapStatus != null && f(currentMapStatus.location)) {
         _numAvailableMapOutputs -= 1
-        mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex)
+        mapIdToMapIndex.remove(currentMapStatus.mapId)
+        mapStatusesDeleted(mapIndex) = currentMapStatus
         mapStatuses(mapIndex) = null
         invalidateSerializedMapOutputStatusCache()
       }
```
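Note the collection swap in the first two hunks: Spark's `org.apache.spark.util.collection.OpenHashMap` is optimized for insertion and lookup and does not support removing keys, so the cleanup calls require switching to `scala.collection.mutable.HashMap`. The sketch below states the consistency property that `addMapOutput`, `removeMapOutput`, and `removeOutputsByFilter` now preserve; it is a hypothetical checker, not part of the patch, and `Status` is a minimal stand-in for `MapStatus`:

```scala
import scala.collection.mutable

object MapIndexInvariant {
  // Minimal stand-in for MapStatus; only mapId matters for this check.
  trait Status { def mapId: Long }

  def check(mapStatuses: Array[Status], mapIdToMapIndex: mutable.HashMap[Long, Int]): Unit = {
    // Every index entry must point at a live status carrying the same mapId...
    mapIdToMapIndex.foreach { case (mapId, mapIndex) =>
      val s = mapStatuses(mapIndex)
      assert(s != null && s.mapId == mapId, s"stale entry for mapId=$mapId")
    }
    // ...and every live status must be reachable through its own mapId.
    mapStatuses.zipWithIndex.foreach {
      case (s, i) if s != null =>
        assert(mapIdToMapIndex.get(s.mapId).contains(i), s"unindexed status at mapIndex=$i")
      case _ =>
    }
  }
}
```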

core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Lines changed: 55 additions & 0 deletions
```diff
@@ -1110,4 +1110,59 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
       rpcEnv.shutdown()
     }
   }
+
+  test(
+    "SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after removeOutputsByFilter"
+  ) {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    try {
+      tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+      tracker.registerShuffle(0, 1, 1)
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000),
+        Array(2L), 0))
+      tracker.removeOutputsOnHost("hostA")
+      assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 0)
+    } finally {
+      tracker.stop()
+      rpcEnv.shutdown()
+    }
+  }
+
+  test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after unregisterMapOutput") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    try {
+      tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+      tracker.registerShuffle(0, 1, 1)
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000),
+        Array(2L), 0))
+      tracker.unregisterMapOutput(0, 0, BlockManagerId("exec-1", "hostA", 1000))
+      assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 0)
+    } finally {
+      tracker.stop()
+      rpcEnv.shutdown()
+    }
+  }
+
+  test("SPARK-48394: mapIdToMapIndex should cleanup unused mapIndexes after registerMapOutput") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    try {
+      tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+      tracker.registerShuffle(0, 1, 1)
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-1", "hostA", 1000),
+        Array(2L), 0))
+      // Another task also finished working on partition 0.
+      tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("exec-2", "hostB", 1000),
+        Array(2L), 1))
+      assert(tracker.shuffleStatuses(0).mapIdToMapIndex.filter(_._2 == 0).size == 1)
+    } finally {
+      tracker.stop()
+      rpcEnv.shutdown()
+    }
+  }
 }
```
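All three tests probe the same property: after the operation, `mapIdToMapIndex` retains only entries whose `mapIndex` still has a live status (zero entries for index 0 in the first two tests, exactly one in the overwrite test). A hypothetical helper the first two assertions could share, assuming it lives inside `MapOutputTrackerSuite` where the `private[spark]` field is visible:

```scala
// Hypothetical helper (not in the patch): assert that no entry in
// mapIdToMapIndex still points at `mapIndex` once its output is gone.
def assertNoEntriesFor(tracker: MapOutputTrackerMaster, shuffleId: Int, mapIndex: Int): Unit = {
  val stale = tracker.shuffleStatuses(shuffleId).mapIdToMapIndex
    .filter { case (_, idx) => idx == mapIndex }
  assert(stale.isEmpty, s"stale mapIdToMapIndex entries: $stale")
}
```

With this in place, `assertNoEntriesFor(tracker, 0, 0)` would replace the `filter(_._2 == 0).size == 0` assertions.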
