Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,34 @@ case class ConcatWs(children: Seq[Expression])
if (children.forall(_.dataType == StringType)) {
// All children are strings. In that case we can construct a fixed size array.
val evals = children.map(_.genCode(ctx))

val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
}.mkString(", ")

ev.copy(evals.map(_.code).mkString("\n") + s"""
UTF8String ${ev.value} = UTF8String.concatWs($inputs);
val separator = evals.head
val strings = evals.tail
val numArgs = strings.length
val args = ctx.freshName("args")

val inputs = strings.zipWithIndex.map { case (eval, index) =>
if (eval.isNull != "true") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiszk I was looking at build warnings and it notes that this compares a ExprValue and String and they will always not be equal. Should it be eval.isNull.code != "true" maybe?

s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
} else {
""
}
}
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(inputs, "valueConcatWs",
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
} else {
inputs.mkString("\n")
}
ev.copy(s"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
$codes
UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
} else {
Expand All @@ -144,32 +165,63 @@ case class ConcatWs(children: Seq[Expression])
child.dataType match {
case StringType =>
("", // we count all the StringType arguments num at once below.
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};")
if (eval.isNull == "true") {
""
} else {
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
})
case _: ArrayType =>
val size = ctx.freshName("n")
(s"""
if (!${eval.isNull}) {
$varargNum += ${eval.value}.numElements();
}
""",
s"""
if (!${eval.isNull}) {
final int $size = ${eval.value}.numElements();
for (int j = 0; j < $size; j ++) {
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
}
if (eval.isNull == "true") {
("", "")
} else {
(s"""
if (!${eval.isNull}) {
$varargNum += ${eval.value}.numElements();
}
""",
s"""
if (!${eval.isNull}) {
final int $size = ${eval.value}.numElements();
for (int j = 0; j < $size; j ++) {
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
}
}
""")
}
""")
}
}.unzip

ev.copy(evals.map(_.code).mkString("\n") +
s"""
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs",
("InternalRow", ctx.INPUT_ROW) :: Nil,
"int",
{ body =>
s"""
int $varargNum = 0;
$body
return $varargNum;
"""
},
_.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we optimize for eval.isNull == "true" in varargCount and varargBuild? since you already did it for the all string cases.

("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
"int",
{ body =>
s"""
$body
return $idxInVararg;
"""
},
_.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
ev.copy(
s"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxInVararg = 0;
${varargCount.mkString("\n")}
$varargCounts
UTF8String[] $array = new UTF8String[$varargNum];
${varargBuild.mkString("\n")}
$varargBuilds
UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
boolean ${ev.isNull} = ${ev.value} == null;
""")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") {
val N = 5000
val sepExpr = Literal.create("#", StringType)
val strings1 = (1 to N).map(x => s"s$x")
val inputsExpr1 = strings1.map(Literal.create(_, StringType))
checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)

val strings2 = (1 to N).map(x => Seq(s"s$x"))
val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
checkEvaluation(
ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
}

test("elt") {
def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
checkEvaluation(
Expand Down