 
 package org.apache.spark.sql.util
 
-import org.apache.spark.SparkException
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
 import org.apache.spark.sql.{functions, QueryTest}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.test.SharedSQLContext
 
-import scala.collection.mutable.ArrayBuffer
-
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
   import functions._
@@ -81,7 +81,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     assert(metrics(0)._3.getMessage == e.getMessage)
   }
 
-  test("get metrics by callback") {
+  test("get numRows metrics by callback") {
     val metrics = ArrayBuffer.empty[Long]
     val listener = new QueryExecutionListener {
       // Only test successful case here, so no need to implement `onFailure`
@@ -103,4 +103,43 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     assert(metrics(1) == 1)
     assert(metrics(2) == 2)
   }
+
+  test("get size metrics by callback") {
+    val metrics = ArrayBuffer.empty[Long]
+    val listener = new QueryExecutionListener {
+      // Only test successful case here, so no need to implement `onFailure`
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metrics += qe.executedPlan.longMetric("dataSize").value.value
+        val bottomAgg = qe.executedPlan.children(0).children(0)
+        metrics += bottomAgg.longMetric("dataSize").value.value
+      }
+    }
+    sqlContext.listenerManager.register(listener)
+
+    val sparkListener = new SaveInfoListener
+    sqlContext.sparkContext.addSparkListener(sparkListener)
+
+    val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j")
+    df.groupBy("i").count().collect()
+
+    def getPeakExecutionMemory(stageId: Int): Long = {
+      val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables
+        .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
+
+      assert(peakMemoryAccumulator.size == 1)
+      peakMemoryAccumulator.head._2.value.toLong
+    }
+
+    assert(sparkListener.getCompletedStageInfos.length == 2)
+    val bottomAggDataSize = getPeakExecutionMemory(0)
+    val topAggDataSize = getPeakExecutionMemory(1)
+
+    // For this simple case, the peakExecutionMemory of a stage should be the data size of the
+    // aggregate operator, as we only have one memory consuming operator per stage.
+    assert(metrics.length == 2)
+    assert(metrics(0) == topAggDataSize)
+    assert(metrics(1) == bottomAggDataSize)
+  }
 }