diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c341943187820..ee65073148c6d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -125,13 +125,34 @@ case class ConcatWs(children: Seq[Expression]) if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.genCode(ctx)) - - val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String) null : ${eval.value}" - }.mkString(", ") - - ev.copy(evals.map(_.code).mkString("\n") + s""" - UTF8String ${ev.value} = UTF8String.concatWs($inputs); + val separator = evals.head + val strings = evals.tail + val numArgs = strings.length + val args = ctx.freshName("args") + + val inputs = strings.zipWithIndex.map { case (eval, index) => + if (eval.isNull != "true") { + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } else { + "" + } + } + val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + ctx.splitExpressions(inputs, "valueConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) + } else { + inputs.mkString("\n") + } + ev.copy(s""" + UTF8String[] $args = new UTF8String[$numArgs]; + ${separator.code} + $codes + UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args); boolean ${ev.isNull} = ${ev.value} == null; """) } else { @@ -144,32 +165,63 @@ case class ConcatWs(children: Seq[Expression]) child.dataType match { case StringType => ("", // we count all the StringType arguments num at once below. - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};") + if (eval.isNull == "true") { + "" + } else { + s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + }) case _: ArrayType => val size = ctx.freshName("n") - (s""" - if (!${eval.isNull}) { - $varargNum += ${eval.value}.numElements(); - } - """, - s""" - if (!${eval.isNull}) { - final int $size = ${eval.value}.numElements(); - for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; - } + if (eval.isNull == "true") { + ("", "") + } else { + (s""" + if (!${eval.isNull}) { + $varargNum += ${eval.value}.numElements(); + } + """, + s""" + if (!${eval.isNull}) { + final int $size = ${eval.value}.numElements(); + for (int j = 0; j < $size; j ++) { + $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + } + } + """) } - """) } }.unzip - ev.copy(evals.map(_.code).mkString("\n") + - s""" + val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) + val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + "int", + { body => + s""" + int $varargNum = 0; + $body + return $varargNum; + """ + }, + _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) + val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs", + ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + "int", + { body => + s""" + $body + return $idxInVararg; + """ + }, + _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + ev.copy( + s""" + $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxInVararg = 0; - ${varargCount.mkString("\n")} + $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; - ${varargBuild.mkString("\n")} + $varargBuilds UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array); boolean ${ev.isNull} = ${ev.value} == null; """) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 18ef4bc37c2b5..89d1e61ad72dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -74,6 +74,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") { + val N = 5000 + val sepExpr = Literal.create("#", StringType) + val strings1 = (1 to N).map(x => s"s$x") + val inputsExpr1 = strings1.map(Literal.create(_, StringType)) + checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow) + + val strings2 = (1 to N).map(x => Seq(s"s$x")) + val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType))) + checkEvaluation( + ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow) + } + test("elt") { def testElt(result: String, n: java.lang.Integer, args: String*): Unit = { checkEvaluation(