diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ddcf61b882d3..b28e611d4532 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -124,6 +124,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /** Specifies sort order for each partition requirements on the input data for this operator. */
   def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
+  /** The logical plan this physical operator was planned from, if the tag has been set. */
+  def logicalPlan: Option[LogicalPlan] = {
+    getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+  }
+
   /**
    * Returns the result of this query as an RDD[InternalRow] by delegating to `doExecute` after
    * preparations.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 3cd02b984d33..1e38cac5f092 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -56,7 +56,7 @@ private[execution] object SparkPlanInfo {
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
-      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
+      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType, metric.stats)
     }
 
     // dump the file scan metadata (e.g file path) to event log
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index fd4a7897c7ad..0832fada7a4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -46,8 +46,13 @@ case class BroadcastHashJoinExec(
     right: SparkPlan)
   extends BinaryExecNode with HashJoin with CodegenSupport {
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  override lazy val metrics = {
+    Map("numOutputRows" ->
+      SQLMetrics.createMetric(
+        sparkContext,
+        "number of output rows",
+        logicalPlan.map(_.stats.rowCount.map(_.toLong).getOrElse(-1L)).getOrElse(-1L)))
+  }
 
   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index adb81519dbc8..a2caad69ca9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,5 @@ import org.apache.spark.annotation.DeveloperApi
 class SQLMetricInfo(
     val name: String,
     val accumulatorId: Long,
-    val metricType: String)
+    val metricType: String,
+    val stats: Long = -1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 19809b07508d..65d68f8d51cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -33,7 +33,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
  * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates
  * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]].
  */
-class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+class SQLMetric(val metricType: String, initValue: Long = 0L, val stats: Long = -1L) extends
+  AccumulatorV2[Long, Long] {
   // This is a workaround for SPARK-11013.
   // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
   // update it at the end of task and the value will be at least 0. Then we can filter out the -1
@@ -42,7 +43,7 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
   private var _zeroValue = initValue
 
   override def copy(): SQLMetric = {
-    val newAcc = new SQLMetric(metricType, _value)
+    val newAcc = new SQLMetric(metricType, _value, stats)
     newAcc._zeroValue = initValue
     newAcc
   }
@@ -96,8 +97,8 @@ object SQLMetrics {
     metric.set((v * baseForAvgMetric).toLong)
   }
 
-  def createMetric(sc: SparkContext, name: String): SQLMetric = {
-    val acc = new SQLMetric(SUM_METRIC)
+  def createMetric(sc: SparkContext, name: String, stats: Long = -1): SQLMetric = {
+    val acc = new SQLMetric(SUM_METRIC, stats = stats)
     acc.register(sc, name = Some(name), countFailedValues = false)
     acc
   }
@@ -193,6 +194,14 @@ object SQLMetrics {
     }
   }
 
+  def stringStats(value: Long): String = {
+    if (value < 0) {
+      ""
+    } else {
+      s" est: ${stringValue(SUM_METRIC, Seq(value))}"
+    }
+  }
+
   /**
    * Updates metrics based on the driver side value. This is useful for certain metrics that
    * are only updated on the driver, e.g. subquery execution time, or number of files.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index e496de1b05e4..76c59c1f0691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -180,17 +180,20 @@ class SQLAppStatusListener(
   }
 
   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
-    val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
+    val metricMap = exec.metrics.map { m => (m.accumulatorId, m) }.toMap
     val metrics = exec.stages.toSeq
       .flatMap { stageId => Option(stageMetrics.get(stageId)) }
      .flatMap(_.taskMetrics.values().asScala)
       .flatMap { metrics => metrics.ids.zip(metrics.values) }
 
     val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
-      .filter { case (id, _) => metricTypes.contains(id) }
+      .filter { case (id, _) => metricMap.contains(id) }
       .groupBy(_._1)
       .map { case (id, values) =>
-        id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+        val metric = metricMap(id)
+        val value = SQLMetrics.stringValue(metric.metricType, values.map(_._2))
+        val stats = SQLMetrics.stringStats(metric.stats)
+        id -> (value + stats)
       }
 
     // Check the execution again for whether the aggregated metrics data has been calculated.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
index 241001a857c8..c3083d8e41bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
@@ -142,4 +142,5 @@ class SparkPlanGraphNodeWrapper(
 case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    metricType: String)
+    metricType: String,
+    stats: Long = -1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index b864ad1c7108..26d6b74381a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,7 +80,7 @@ object SparkPlanGraph {
     planInfo.nodeName match {
       case "WholeStageCodegen" =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats)
         }
 
         val cluster = new SparkPlanGraphCluster(
@@ -114,7 +114,7 @@ object SparkPlanGraph {
         edges += SparkPlanGraphEdge(node.id, parent.id)
       case name =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats)
         }
         val node = new SparkPlanGraphNode(
           nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 81cac9d95b7a..8799977a76ab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -1430,4 +1430,25 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       assert(catalogStats.rowCount.isEmpty)
     }
   }
+
+  test("statistics estimate for broadcast hash join numOutputRows metric") {
+    withTable("t1", "t2") {
+      withSQLConf(SQLConf.CBO_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40",
+        SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+        sql("CREATE TABLE t1 (key INT, a2 STRING, a3 DOUBLE)")
+        sql("INSERT INTO TABLE t1 SELECT 1, 'a', 10.0")
+        sql("INSERT INTO TABLE t1 SELECT 1, 'b', null")
+        sql("ANALYZE TABLE t1 COMPUTE STATISTICS FOR ALL COLUMNS")
+
+        sql("CREATE TABLE t2 (key INT, b2 STRING, b3 DOUBLE)")
+        sql("INSERT INTO TABLE t2 SELECT 1, 'a', 10.0")
+        sql("ANALYZE TABLE t2 COMPUTE STATISTICS FOR ALL COLUMNS")
+
+        val df = sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key")
+        assert(df.queryExecution.sparkPlan.isInstanceOf[BroadcastHashJoinExec])
+        assert(df.queryExecution.sparkPlan.metrics("numOutputRows").stats == 2)
+      }
+    }
+  }
 }
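For reviewers, a minimal sketch (not part of the patch; the object name MetricStatsSketch is made up) of how the new stats estimate travels from a metric into the UI string, assuming a build that already includes the changes above:

// Illustrative only: mimics what SQLAppStatusListener.aggregateMetrics now does,
// i.e. append the stringStats() suffix to the aggregated metric value.
package org.apache.spark.sql.execution.metric

object MetricStatsSketch {
  def main(args: Array[String]): Unit = {
    // A "sum" metric carrying an estimated row count of 2, the same shape
    // BroadcastHashJoinExec builds when CBO supplies stats.rowCount == Some(2).
    val metric = new SQLMetric("sum", stats = 2L)
    metric.add(2L) // pretend two rows were actually produced at runtime

    // stringStats returns "" when no estimate is available (stats < 0),
    // otherwise a suffix rendered next to the aggregated value.
    val rendered =
      SQLMetrics.stringValue(metric.metricType, Seq(metric.value)) +
        SQLMetrics.stringStats(metric.stats)
    println(rendered) // expected to look like: 2 est: 2
  }
}

With no estimate (the default stats = -1), stringStats returns an empty string, so existing metrics render exactly as they did before this change.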