Skip to content

Commit 2f304f0

Browse files
committed
Use task completion listener instead.
1 parent 93c93da commit 2f304f0

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
234234
row.writeToStream(out, buffer)
235235
count += 1
236236
}
237-
// If iterator has more elements, we should consume them all. Otherwise under wholestage
238-
// codegen, as we release resources after consuming all elements (e.g., HashAggregate), it
239-
// will cause problems such as memory leak.
240-
while (iter.hasNext) {
241-
iter.next()
242-
}
243237
out.writeInt(-1)
244238
out.flush()
245239
out.close()

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,24 +295,35 @@ case class HashAggregateExec(
295295
private var hashMapTerm: String = _
296296
private var sorterTerm: String = _
297297

298+
// Becasue Dataset.show/take methods will end of iteraton before reaching the end of all rows,
299+
// we may not release resources then and cause memory leak. So we need to hold the reference
300+
// of the hash map if it is created and release the resources after task completion.
301+
private var hashMapToRelease: UnsafeFixedWidthAggregationMap = _
302+
298303
/**
299304
* This is called by generated Java class, should be public.
300305
*/
301306
def createHashMap(): UnsafeFixedWidthAggregationMap = {
302307
// create initialized aggregate buffer
303308
val initExpr = declFunctions.flatMap(f => f.initialValues)
304309
val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
310+
val context = TaskContext.get()
305311

306312
// create hashMap
307-
new UnsafeFixedWidthAggregationMap(
313+
hashMapToRelease = new UnsafeFixedWidthAggregationMap(
308314
initialBuffer,
309315
bufferSchema,
310316
groupingKeySchema,
311-
TaskContext.get().taskMemoryManager(),
317+
context.taskMemoryManager(),
312318
1024 * 16, // initial capacity
313-
TaskContext.get().taskMemoryManager().pageSizeBytes,
319+
context.taskMemoryManager().pageSizeBytes,
314320
false // disable tracking of performance metrics
315321
)
322+
323+
// Release the resources of the hash map when the end of task.
324+
context.addTaskCompletionListener(_ => hashMapToRelease.free())
325+
326+
hashMapToRelease
316327
}
317328

318329
def getTaskMemoryManager(): TaskMemoryManager = {

0 commit comments

Comments
 (0)