diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 34998cbd61552..6272e612573c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -245,5 +245,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { println(org.apache.spark.sql.execution.debug.codegenString(executedPlan)) // scalastyle:on println } + + /** + * Get the WholeStageCodegenExec subtrees of executedPlan and their generated code. + * + * @return one (subtree, generated code) pair per WholeStageCodegenExec subtree + */ + def codegenToSeq(): Seq[(String, String)] = { + org.apache.spark.sql.execution.debug.codegenStringSeq(executedPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 0395c43ba2cbc..a717cbd4a7df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -50,7 +50,31 @@ package object debug { // scalastyle:on println } + /** + * Render a plan's WholeStageCodegenExec subtrees and their generated code as one String. + * + * @param plan the SparkPlan to search for WholeStageCodegenExec nodes + * @return a count header, then a "== Subtree i / N ==" section per subtree with its code + */ def codegenString(plan: SparkPlan): String = { + val codegenSeq = codegenStringSeq(plan) + var output = s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n" + for (((subtree, code), i) <- codegenSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n" + output += subtree + output += "\nGenerated code:\n" + output += s"${code}\n" + } + output + } + + /** + * Collect a plan's WholeStageCodegenExec subtrees together with their generated code. 
+ * + * @param plan the SparkPlan to search for WholeStageCodegenExec nodes + * @return (subtree.toString, formatted source) pairs, in HashSet iteration order + */ + def codegenStringSeq(plan: SparkPlan): Seq[(String, String)] = { val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() plan transform { case s: WholeStageCodegenExec => @@ -58,15 +82,10 @@ package object debug { s case s => s } - var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" - for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" - output += s - output += "\nGenerated code:\n" - val (_, source) = s.doCodeGen() - output += s"${CodeFormatter.format(source)}\n" + codegenSubtrees.toSeq.map { subtree => + val (_, source) = subtree.doCodeGen() + (subtree.toString, CodeFormatter.format(source)) } - output } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4fc52c99fbeeb..adcaf2d76519f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -38,4 +38,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } + + test("debugCodegenStringSeq") { + val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall{ case (subtree, code) => + subtree.contains("Range") && code.contains("Object[]")}) + } }