diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5c9e604a8d29..cb94f31e5a3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 
 import java.io.ByteArrayInputStream
+import java.lang.Character._
 import java.util.{Map => JavaMap}
 
 import scala.collection.JavaConverters._
@@ -1103,6 +1104,29 @@ class CodegenContext {
   }
 }
 
+object CodegenContext {
+
+  private val javaKeywords = Set(
+    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const",
+    "continue", "default", "do", "double", "else", "extends", "false", "final", "finally", "float",
+    "for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long", "native",
+    "new", "null", "package", "private", "protected", "public", "return", "short", "static",
+    "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true",
+    "try", "void", "volatile", "while"
+  )
+
+  /**
+   * Returns true if the given `str` is a valid Java identifier.
+   */
+  def isJavaIdentifier(str: String): Boolean = str match {
+    case null | "" =>
+      false
+    case _ =>
+      !javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) &&
+        (1 until str.length).forall(i => isJavaIdentifierPart(str.charAt(i)))
+  }
+}
+
 /**
  * A wrapper for generated class, defines a `generate` method so that we can pass extra objects
  * into generated class.
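Note: `isJavaIdentifier` is used by the `HashAggregateExec` changes below to decide whether an `ExprCode` slot holds a plain generated variable name (which can safely become a parameter of a split-out function) or an arbitrary code fragment such as a literal. A quick sketch of the contract (illustrative, not part of the patch):

    CodegenContext.isJavaIdentifier("agg_value1")   // true: a generated variable name
    CodegenContext.isJavaIdentifier("true")         // false: a Java keyword/literal
    CodegenContext.isJavaIdentifier("3agg")         // false: cannot start with a digit
    CodegenContext.isJavaIdentifier("")             // false: empty string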
" + + "This default value is 127 because the maximum length of parameters in non-static Java " + + "methods is 254 and a parameter of type long or double contributes " + + "two units to the length.") + .intConf + .createWithDefault(127) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + @@ -1156,6 +1167,8 @@ class SQLConf extends Serializable with Logging { def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def maxParamNumInJavaMethod: Int = getConf(MAX_PARAM_NUM_IN_JAVA_METHOD) + def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 40bf29bb3b57..f67eefab4d1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -394,4 +394,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Map("add" -> Literal(1))).genCode(ctx) assert(ctx.mutableStates.isEmpty) } + + test("SPARK-21870 check if CodegenContext.isJavaIdentifier works correctly") { + import CodegenContext.isJavaIdentifier + // positive cases + assert(isJavaIdentifier("agg_value")) + assert(isJavaIdentifier("agg_value1")) + assert(isJavaIdentifier("bhj_value4")) + assert(isJavaIdentifier("smj_value6")) + assert(isJavaIdentifier("rdd_value7")) + assert(isJavaIdentifier("scan_isNull")) + assert(isJavaIdentifier("test")) + // negative cases + assert(!isJavaIdentifier("true")) + assert(!isJavaIdentifier("false")) + assert(!isJavaIdentifier("390239")) + assert(!isJavaIdentifier(""""literal"""")) + assert(!isJavaIdentifier(""""double"""")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9cadd13999e7..83b0f807f541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD @@ -257,6 +259,78 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + context: CodegenContext, + aggregateExpression: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { + // `argSet` collects all the pairs of variable names and their types, the first in the pair is + // a type name and the second is a variable name. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 40bf29bb3b57..f67eefab4d1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -394,4 +394,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       Map("add" -> Literal(1))).genCode(ctx)
     assert(ctx.mutableStates.isEmpty)
   }
+
+  test("SPARK-21870 check if CodegenContext.isJavaIdentifier works correctly") {
+    import CodegenContext.isJavaIdentifier
+    // positive cases
+    assert(isJavaIdentifier("agg_value"))
+    assert(isJavaIdentifier("agg_value1"))
+    assert(isJavaIdentifier("bhj_value4"))
+    assert(isJavaIdentifier("smj_value6"))
+    assert(isJavaIdentifier("rdd_value7"))
+    assert(isJavaIdentifier("scan_isNull"))
+    assert(isJavaIdentifier("test"))
+    // negative cases
+    assert(!isJavaIdentifier("true"))
+    assert(!isJavaIdentifier("false"))
+    assert(!isJavaIdentifier("390239"))
+    assert(!isJavaIdentifier(""""literal""""))
+    assert(!isJavaIdentifier(""""double""""))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 9cadd13999e7..83b0f807f541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.RDD
@@ -257,6 +259,78 @@ case class HashAggregateExec(
     """.stripMargin
   }
 
+  // Extracts all the input variable references in a given aggregate expression. The result is
+  // used to split the aggregation logic into small functions.
+  private def getInputVariableReferences(
+      context: CodegenContext,
+      aggregateExpression: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
+    // `argSet` collects pairs of a type name and a variable name; the first element of each
+    // pair is the type name and the second is the variable name.
+    val argSet = mutable.Set[(String, String)]()
+    val stack = mutable.Stack[Expression](aggregateExpression)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val exprCode = subExprs(e)
+          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
+            argSet += ((context.javaType(e.dataType), exprCode.value))
+          }
+          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
+            argSet += (("boolean", exprCode.isNull))
+          }
+          // Since the children possibly have common expressions, we push them here, too
+          stack.pushAll(e.children)
+        case ref: BoundReference
+            if context.currentVars != null && context.currentVars(ref.ordinal) != null =>
+          val value = context.currentVars(ref.ordinal).value
+          val isNull = context.currentVars(ref.ordinal).isNull
+          if (CodegenContext.isJavaIdentifier(value)) {
+            argSet += ((context.javaType(ref.dataType), value))
+          }
+          if (CodegenContext.isJavaIdentifier(isNull)) {
+            argSet += (("boolean", isNull))
+          }
+        case _: BoundReference =>
+          argSet += (("InternalRow", context.INPUT_ROW))
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
+  // Splits the aggregate evaluation code into small functions because the JVM does not
+  // JIT-compile functions that are too long.
+  private def splitAggregateExpressions(
+      context: CodegenContext,
+      aggregateExpressions: Seq[Expression],
+      codes: Seq[String],
+      subExprs: Map[Expression, SubExprEliminationState],
+      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
+    aggregateExpressions.zipWithIndex.map { case (aggExpr, i) =>
+      val args = (getInputVariableReferences(context, aggExpr, subExprs) ++ otherArgs).toSeq
+
+      // This method gives up splitting the code if the number of parameters goes over
+      // `maxParamNumInJavaMethod`.
+      if (args.size <= sqlContext.conf.maxParamNumInJavaMethod) {
+        val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}")
+        val argList = args.map(a => s"${a._1} ${a._2}").mkString(", ")
+        val doAggValFuncName = context.addNewFunction(doAggVal,
+          s"""
+             | private void $doAggVal($argList) throws java.io.IOException {
+             |   ${codes(i)}
+             | }
+           """.stripMargin)
+
+        val inputVariables = args.map(_._2).mkString(", ")
+        s"$doAggValFuncName($inputVariables);"
+      } else {
+        codes(i)
+      }
+    }
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -269,28 +343,53 @@ case class HashAggregateExec(
         e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
       }
     }
-    ctx.currentVars = bufVars ++ input
+
+    // We need to copy the aggregation buffer to local variables first because each aggregate
+    // function directly updates the buffer when it finishes.
+    val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) =>
+      val isNull = ctx.freshName("localBufIsNull")
+      val value = ctx.freshName("localBufValue")
+      val initLocalVars = s"""
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+       """.stripMargin
+      ExprCode(initLocalVars, isNull, value)
+    }
+
+    val initLocalBufVar = evaluateVariables(localBufVars)
+
+    ctx.currentVars = localBufVars ++ input
     val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
       boundUpdateExpr.map(_.genCode(ctx))
     }
-    // aggregate buffer should be updated atomic
-    val updates = aggVals.zipWithIndex.map { case (ev, i) =>
+
+    val evalAndUpdateCodes = aggVals.zipWithIndex.map { case (ev, i) =>
       s"""
+         | // evaluate aggregate function
+         | ${ev.code}
+         | // update aggregation buffer
         | ${bufVars(i).isNull} = ${ev.isNull};
         | ${bufVars(i).value} = ${ev.value};
       """.stripMargin
     }
+
+    val updateAggValCode = splitAggregateExpressions(
+      context = ctx,
+      aggregateExpressions = boundUpdateExpr,
+      codes = evalAndUpdateCodes,
+      subExprs = subExprs.states)
+
     s"""
        | // do aggregate
+       | // copy aggregation buffer to the local
+       | $initLocalBufVar
        | // common sub-expressions
       | $effectiveCodes
-       | // evaluate aggregate function
-       | ${evaluateVariables(aggVals)}
-       | // update aggregation buffer
-       | ${updates.mkString("\n").trim}
+       | // process aggregate functions to update aggregation buffer
+       | ${updateAggValCode.mkString("\n")}
     """.stripMargin
   }
 
@@ -825,52 +924,92 @@ case class HashAggregateExec(
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
 
     val updateRowInRegularHashMap: String = {
-      ctx.INPUT_ROW = unsafeRowBuffer
+      // We need to copy the aggregation row buffer to a local row first because each aggregate
+      // function directly updates the buffer when it finishes.
+      val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
+      val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();"
+
+      ctx.INPUT_ROW = localRowBuffer
       val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
         boundUpdateExpr.map(_.genCode(ctx))
       }
-      val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
+
+      val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
         val dt = updateExpr(i).dataType
-        ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+        val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+        s"""
+           | // evaluate aggregate function
+           | ${ev.code}
+           | // update unsafe row buffer
+           | $updateColumnCode
+         """.stripMargin
       }
+
+      val updateAggValCode = splitAggregateExpressions(
+        context = ctx,
+        aggregateExpressions = boundUpdateExpr,
+        codes = evalAndUpdateCodes,
+        subExprs = subExprs.states,
+        otherArgs = Seq(("InternalRow", unsafeRowBuffer)))
+
       s"""
-         |// common sub-expressions
-         |$effectiveCodes
-         |// evaluate aggregate function
-         |${evaluateVariables(unsafeRowBufferEvals)}
-         |// update unsafe row buffer
-         |${updateUnsafeRowBuffer.mkString("\n").trim}
+         | // do aggregate
+         | // copy aggregation row buffer to the local
+         | $initLocalRowBuffer
+         | // common sub-expressions
+         | $effectiveCodes
+         | // process aggregate functions to update aggregation buffer
+         | ${updateAggValCode.mkString("\n")}
       """.stripMargin
     }
 
    val updateRowInHashMap: String = {
       if (isFastHashMapEnabled) {
-        ctx.INPUT_ROW = fastRowBuffer
+        // We need to copy the aggregation row buffer to a local row first because each aggregate
+        // function directly updates the buffer when it finishes.
+        val localRowBuffer = ctx.freshName("localFastRowBuffer")
+        val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();"
+
+        ctx.INPUT_ROW = localRowBuffer
         val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
         val effectiveCodes = subExprs.codes.mkString("\n")
         val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
           boundUpdateExpr.map(_.genCode(ctx))
         }
-        val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
+
+        val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) =>
           val dt = updateExpr(i).dataType
-          ctx.updateColumn(
-            fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
+          val updateColumnCode = ctx.updateColumn(
+            fastRowBuffer, dt, i, ev, updateExpr(i).nullable)
+          s"""
+             | // evaluate aggregate function
+             | ${ev.code}
+             | // update fast row
+             | $updateColumnCode
+           """.stripMargin
        }
+
+        val updateAggValCode = splitAggregateExpressions(
+          context = ctx,
+          aggregateExpressions = boundUpdateExpr,
+          codes = evalAndUpdateCodes,
+          subExprs = subExprs.states,
+          otherArgs = Seq(("InternalRow", fastRowBuffer)))
+
         // If fast hash map is on, we first generate code to update row in fast hash map, if the
         // previous look-up hit fast hash map. Otherwise, update row in regular hash map.
         s"""
           |if ($fastRowBuffer != null) {
+          |  // copy aggregation row buffer to the local
+          |  $initLocalRowBuffer
           |  // common sub-expressions
           |  $effectiveCodes
-          |  // evaluate aggregate function
-          |  ${evaluateVariables(fastRowEvals)}
-          |  // update fast row
-          |  ${updateFastRow.mkString("\n").trim}
+          |  // process aggregate functions to update aggregation buffer
+          |  ${updateAggValCode.mkString("\n")}
          |} else {
           |  $updateRowInRegularHashMap
           |}
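Note: to make the splitting concrete, here is a rough sketch of what `splitAggregateExpressions` registers through `context.addNewFunction` for a single `sum` update expression, and the call it emits at the original site. The generated Java below is illustrative only; the variable names are invented for the example, and the real ones come from `freshName`:

    // generated helper added to the compiled class
    private void doAggregateVal_sum(boolean agg_bufIsNull, long agg_bufValue,
        int inputadapter_value) throws java.io.IOException {
      // evaluate aggregate function and update the aggregation buffer
    }

    // emitted in place of the previously inlined update code
    doAggregateVal_sum(agg_bufIsNull, agg_bufValue, inputadapter_value);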
s""" |if ($fastRowBuffer != null) { + | // copy aggregation row buffer to the local + | $initLocalRowBuffer | // common sub-expressions | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvals)} - | // update fast row - | ${updateFastRow.mkString("\n").trim} + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} |} else { | $updateRowInRegularHashMap |} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c4..67fdb49e9101 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -211,11 +211,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) - val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) - assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val (_, smallCodeSize) = CodeGenerator.compile(codeWithShortFunctions) val codeWithLongFunctions = genGroupByCode(20) - val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) - assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val (_, largeCodeSize) = CodeGenerator.compile(codeWithLongFunctions) + // Just checking if long functions have the large value of max code size + assert(largeCodeSize > smallCodeSize) } test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { @@ -236,4 +236,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21870 check the case where the number of parameters goes over the limit") { + withSQLConf("spark.sql.codegen.maxParamNumInJavaMethod" -> "2") { + sql("CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES (1, 1, 1) AS t(a, b, c)") + val df = sql("SELECT SUM(a + b + c) AS sum FROM t") + assert(df.collect === Seq(Row(3))) + } + } }