Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|$mapDataClass ${ev.value} = null;
""".stripMargin

val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| $argsName[$i] = ${m.value};
| if (${m.isNull}) {
| $hasNullName = true;
| }
|}
""".stripMargin
val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map {
case ((m, true), i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| if (!${m.isNull}) {
| $argsName[$i] = ${m.value};
| } else {
| $hasNullName = true;
| }
|}
""".stripMargin
case ((m, false), i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| $argsName[$i] = ${m.value};
|}
""".stripMargin
}

val codes = ctx.splitExpressionsWithCurrentInputs(
Expand All @@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
val finKeysName = ctx.freshName("finalKeys")
val finValsName = ctx.freshName("finalValues")

val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) {
genCodeForPrimitiveArrays(ctx, keyType, false)
} else {
genCodeForNonPrimitiveArrays(ctx, keyType)
}

val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, valueType)
}
val valueConcat =
if (valueType.sameType(keyType) &&
!(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
keyConcat
} else if (CodeGenerator.isPrimitiveType(valueType)) {
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, valueType)
}

val keyArgsName = ctx.freshName("keyArgs")
val valArgsName = ctx.freshName("valArgs")
Expand All @@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
| $numElementsName + " elements due to exceeding the map size limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
| }
| $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
| $arrayDataClass $finKeysName = $keyConcat($keyArgsName,
| (int) $numElementsName);
| $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
| $arrayDataClass $finValsName = $valueConcat($valArgsName,
| (int) $numElementsName);
| ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
|}
Expand Down Expand Up @@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
setterCode1
}

s"""
|new Object() {
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $setterCode
| $counter++;
| }
| }
| return $arrayData;
| }
|}""".stripMargin.stripPrefix("\n")
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $setterCode
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin

ctx.addNewFunction(concat, concatDef)
}

private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
Expand All @@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements")

s"""
|new Object() {
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {;
| Object[] $arrayData = new Object[$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin.stripPrefix("\n")
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| Object[] $arrayData = new Object[$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
|}
""".stripMargin

ctx.addNewFunction(concat, concatDef)
}

override def prettyName: String = "map_concat"
Expand Down Expand Up @@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val args = ctx.freshName("args")
val hasNull = ctx.freshName("hasNull")

val inputs = evals.zipWithIndex.map { case (eval, index) =>
s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map {
case ((eval, true), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| if (!${eval.isNull}) {
| $args[$index] = ${eval.value};
| } else {
| $hasNull = true;
| }
|}
""".stripMargin
case ((eval, false), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| $args[$index] = ${eval.value};
|}
""".stripMargin
}

val (concatenator, initCode) = dataType match {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil,
returnType = "boolean",
makeSplitFunction = body =>
s"""
|$body
|return $hasNull;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n")
)

val (concat, initCode) = dataType match {
case BinaryType =>
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
(s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, _) =>
val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType)
("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, containsNull) =>
val concat = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType, containsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, elementType)
}
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
(concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: Nil)
ev.copy(code"""
$initCode
$codes
$javaType ${ev.value} = $concatenator.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")

ev.copy(code =
code"""
|boolean $hasNull = false;
|$initCode
|$codes
|$javaType ${ev.value} = null;
|if (!$hasNull) {
| ${ev.value} = $concat($args);
|}
|boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
}

private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
Expand All @@ -2322,49 +2369,55 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
(code, numElements)
}

private def nullArgumentProtection() : String = {
if (nullable) {
s"""
|for (int z = 0; z < ${children.length}; z++) {
| if (args[z] == null) return null;
|}
""".stripMargin
} else {
""
}
}

private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
private def genCodeForPrimitiveArrays(
ctx: CodegenContext,
elementType: DataType,
checkForNull: Boolean): String = {
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

s"""
|new Object() {
| public ArrayData concat($javaType[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| if (args[y].isNullAt(z)) {
| $arrayData.setNullAt($counter);
| } else {
| $arrayData.set$primitiveValueTypeName(
| $counter,
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
| );
| }
| $counter++;
| }
| }
| return $arrayData;
| }
|}""".stripMargin.stripPrefix("\n")
val setterCode =
s"""
|$arrayData.set$primitiveValueTypeName(
| $counter,
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
|);
""".stripMargin

val nullSafeSetterCode = if (checkForNull) {
s"""
|if (args[y].isNullAt(z)) {
| $arrayData.setNullAt($counter);
|} else {
| $setterCode
|}
""".stripMargin
} else {
setterCode
}

val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] args) {
| $numElemCode
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $nullSafeSetterCode
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin

ctx.addNewFunction(concat, concatDef)
}

private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
Expand All @@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

s"""
|new Object() {
| public ArrayData concat($javaType[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| Object[] $arrayData = new Object[(int)$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin.stripPrefix("\n")
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] args) {
| $numElemCode
| Object[] $arrayData = new Object[(int)$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
|}
""".stripMargin

ctx.addNewFunction(concat, concatDef)
}

override def toString: String = s"concat(${children.mkString(", ")})"
Expand Down
Loading