Skip to content

Commit a9a4a8b

Browse files
committed
SPARK-17503: Fix memory leak in Memory store when unable to cache the whole RDD
1 parent 72eec70 commit a9a4a8b

File tree

2 files changed

+87
-14
lines changed

2 files changed

+87
-14
lines changed

core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -663,31 +663,43 @@ private[spark] class MemoryStore(
663663
/**
 * An iterator that yields every element of `unrolled` followed by every element of `rest`,
 * and gives the unroll memory back to the [[MemoryStore]] as soon as `unrolled` is found to
 * be exhausted (or when [[close]] is called, whichever happens first).
 *
 * The unroll memory is released at most once: releasing it also nulls out the reference to
 * `unrolled`, and every release path is guarded by a null check on that reference.
 *
 * @param memoryStore  the store on which `releaseUnrollMemoryForThisTask` is invoked.
 * @param unrollMemory the amount passed to `releaseUnrollMemoryForThisTask` on release.
 * @param unrolled     the first iterator to drain; the reference is dropped once released.
 * @param rest         the iterator that is consumed after `unrolled` is exhausted.
 */
private[storage] class PartiallyUnrolledIterator[T](
    memoryStore: MemoryStore,
    unrollMemory: Long,
    private[this] var unrolled: Iterator[T],
    rest: Iterator[T])
  extends Iterator[T] {

  /** Releases the unroll memory and drops the reference to the unrolled iterator. */
  private def releaseUnrollMemory(): Unit = {
    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
    // SPARK-17503: Drop the reference so the unrolled values become eligible for garbage
    // collection before this PartiallyUnrolledIterator itself goes out of scope.
    unrolled = null
  }

  override def hasNext: Boolean = {
    if (unrolled == null) {
      rest.hasNext
    } else if (!unrolled.hasNext) {
      // The unrolled part has just been fully consumed: release its memory right away.
      releaseUnrollMemory()
      rest.hasNext
    } else {
      true
    }
  }

  override def next(): T = {
    // Also fall through to `rest` when `unrolled` is exhausted but has not been released yet,
    // so that calling next() without a preceding hasNext does not fail with
    // NoSuchElementException while `rest` still has elements. The unroll memory is still
    // released by the next hasNext call or by close().
    if (unrolled == null || !unrolled.hasNext) {
      rest.next()
    } else {
      unrolled.next()
    }
  }

  /**
   * Called to dispose of this iterator and free its memory.
   *
   * Does nothing if the unroll memory has already been released.
   */
  def close(): Unit = {
    if (unrolled != null) {
      releaseUnrollMemory()
    }
  }
}
693705

core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.storage
19+
20+
import org.mockito.Matchers
21+
import org.mockito.Mockito._
22+
import org.scalatest.mock.MockitoSugar
23+
24+
import org.apache.spark.SparkFunSuite
25+
import org.apache.spark.memory.MemoryMode.ON_HEAP
26+
import org.apache.spark.storage.memory.{MemoryStore, PartiallyUnrolledIterator}
27+
28+
/**
 * Checks that [[PartiallyUnrolledIterator]] yields the unrolled elements followed by the rest,
 * and that it releases the unroll memory on the mocked [[MemoryStore]] exactly once.
 */
class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
  test("join two iterators") {
    val unrollSize = 1000
    val unroll = (0 until unrollSize).iterator
    val restSize = 500
    val rest = (unrollSize until restSize + unrollSize).iterator

    val memoryStore = mock[MemoryStore]
    val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)

    // Firstly, iterate over the unrolled iterator. hasNext is called twice per element to
    // check that repeated calls do not consume anything.
    (0 until unrollSize).foreach { value =>
      assert(joinIterator.hasNext)
      assert(joinIterator.hasNext)
      assert(joinIterator.next() == value)
    }

    // Crossing the boundary: the first hasNext releases the unroll memory, the second one
    // must not release it again. Both must report the elements still pending in `rest`.
    assert(joinIterator.hasNext)
    assert(joinIterator.hasNext)
    verify(memoryStore, times(1))
      .releaseUnrollMemoryForThisTask(Matchers.eq(ON_HEAP), Matchers.eq(unrollSize.toLong))

    // Secondly, iterate over the rest iterator
    (unrollSize until unrollSize + restSize).foreach { value =>
      assert(joinIterator.hasNext)
      assert(joinIterator.hasNext)
      assert(joinIterator.next() == value)
    }

    // Both underlying iterators are now exhausted
    assert(!joinIterator.hasNext)

    joinIterator.close()
    // MemoryStore.releaseUnrollMemoryForThisTask is called only once
    verifyNoMoreInteractions(memoryStore)
  }
}

0 commit comments

Comments
 (0)