From 282e7249c2ed35add40f54087423ca62732b6046 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Tue, 13 Mar 2018 21:35:45 +0100 Subject: [PATCH 01/18] [SPARK-23736][SQL] Implementation of the concat_arrays function concatenating multiple array columns into one. --- python/pyspark/sql/functions.py | 19 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/Expression.scala | 85 ++++++++++ .../expressions/collectionOperations.scala | 153 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 26 +++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 57 +++++++ 7 files changed, 348 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dff590983b4d..f08f3c8f8a57 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def concat_arrays(*cols): + """ + Collection function: Concatenates multiple arrays into one. + + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df.select(concat_arrays(df.a, df.b, df.c).alias("arr")).collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + args = _to_seq(sc, cols, _to_java_column) + return Column(sc._jvm.functions.concat_arrays(args)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e..7c7671acb639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -408,6 +408,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ConcatArrays]("concat_arrays"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d7f9e38915dd..9dcf7164b61e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -699,3 +699,88 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * The trait covers logic for performing null save evaluation and code generation. + */ +trait NullSafeEvaluation extends Expression +{ + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of NullSafeEvaluation. + * If a class utilizing NullSaveEvaluation override [[nullable]], probably should also + * override this. 
+ */ + override def eval(input: InternalRow): Any = + { + val values = children.map(_.eval(input)) + if (values.contains(null)) null + else nullSafeEval(values) + } + + /** + * Called by default [[eval]] implementation. If a class utilizing NullSaveEvaluation keep + * the default nullability, they can override this method to save null-check code. If we need + * full control of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(inputs: Seq[Any]): Any = + sys.error(s"The class utilizing NullSaveEvaluation must override either eval or nullSafeEval") + + /** + * Short hand for generating of null save evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts a sequence of variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: Seq[String] => String): ExprCode = { + nullSafeCodeGen(ctx, ev, values => { + s"${ev.value} = ${f(values)};" + }) + } + + /** + * Called by expressions to generate null safe evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function that accepts a sequence of non-null evaluation result names of children + * and returns Java code to compute the output. + */ + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: Seq[String] => String): ExprCode = { + val gens = children.map(_.genCode(ctx)) + val resultCode = f(gens.map(_.value)) + + if (nullable) { + val nullSafeEval = + (s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ /: children.zip(gens)) { + case (acc, (child, gen)) => + gen.code + ctx.nullSafeExec(child.nullable, gen.isNull)(acc) + } + + ev.copy(code = s""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = s""" + boolean ${ev.isNull} = false; + ${gens.map(_.code).mkString("\n")} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = "false") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb84694c44e..24b732577933 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,8 +21,10 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods /** * Given an array or map, returns its size. Returns -1 if null. @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple arrays into one. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) +case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { + + override def checkInputDataTypes(): TypeCheckResult = { + val arrayCheck = checkInputDataTypesAreArrays + if(arrayCheck.isFailure) arrayCheck + else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } + + private def checkInputDataTypesAreArrays(): TypeCheckResult = + { + val mismatches = children.zipWithIndex.collect { + case (child, idx) if !ArrayType.acceptsType(child.dataType) => + s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." + } + + if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + } + } + + override def dataType: ArrayType = + children + .headOption.map(_.dataType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) + + + override protected def nullSafeEval(inputs: Seq[Any]): Any = { + val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) + new GenericArrayData(elements) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, arrays => { + val elementType = dataType.elementType + if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, ev.value) + } else { + genCodeForConcatOfComplexElements(ctx, arrays, ev.value) + } + }) + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + elements: Seq[String] + ) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = elements + .map(el => s"$variableName += $el.numElements();") + .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s) + (code, variableName) + } + + private def genCodeForConcatOfPrimitiveElements( + ctx: CodegenContext, + elementType: DataType, + elements: Seq[String], + arrayDataName: String + ): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + + val unsafeArraySizeInBytes = s""" + |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + + |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( + |${elementType.defaultSize} * $numElemName + |); + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val assignments = elements.map { el => + s""" + |for(int z = 0; z < $el.numElements(); z++) { + | if($el.isNullAt(z)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | $el.get$primitiveValueTypeName(z) + | ); + | } + | $counter++; + |} + """.stripMargin + }.mkString("\n") + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, 
$baseOffset, $arraySizeName); + |int $counter = 0; + |$assignments + |$arrayDataName = $tempArrayDataName; + """.stripMargin + + } + + private def genCodeForConcatOfComplexElements( + ctx: CodegenContext, + elements: Seq[String], + arrayDataName: String + ): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + + val assignments = elements.map { el => + s""" + |for(int z = 0; z < $el.numElements(); z++) { + | $arrayName[$counter] = $el.array()[z]; + | $counter++; + |} + """.stripMargin + }.mkString("\n") + + s""" + |$numElemCode + |Object[] $arrayName = new Object[$numElemName]; + |int $counter = 0; + |$assignments + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "concat_arrays" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a2..78be7a20a184 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,30 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Concat Arrays") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq(4, null, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a4 = Literal.create(Seq("e", null), ArrayType(StringType)) + val a5 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val an = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ConcatArrays(Seq(a0)), Seq(1, 2, 3)) + checkEvaluation(ConcatArrays(Seq(a0, a2)), Seq(1, 2, 3, 4, null, 6)) + checkEvaluation(ConcatArrays(Seq(a0, a1, a2)), Seq(1, 2, 3, 4, null, 6)) + + checkEvaluation(ConcatArrays(Seq(a3, a4)), Seq("a", "b", "c", "e", null)) + checkEvaluation(ConcatArrays(Seq(a3, a4, a5)), Seq("a", "b", "c", "e", null)) + checkEvaluation(ConcatArrays(Seq(a5)), Seq()) + + checkEvaluation(ConcatArrays(Seq()), Seq()) + checkEvaluation(ConcatArrays(Seq(a0, a0)), Seq(1, 2, 3, 1, 2, 3)) + + checkEvaluation(ConcatArrays(Seq(an)), null) + checkEvaluation(ConcatArrays(Seq(a3, an)), null) + checkEvaluation(ConcatArrays(Seq(an, a3)), null) + checkEvaluation(ConcatArrays(Seq(a3, an, a4)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9ca9a899634..2ca9ffbc7703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3046,6 +3046,14 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Merges multiple arrays into one by putting elements from the specific array after elements + * from the previous array. If any of the arrays is null, null is returned. 
+ * @group collection_funcs + * @since 2.4.0 + */ + def concat_arrays(columns: Column*): Column = withExpr { ConcatArrays(columns.map(_.expr)) } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f45..94d1f2fa1926 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("concat arrays function") { + val nint : Int = null.asInstanceOf[Int] + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + (Seq(1), Seq(2, 3, 4), Seq(5, 6), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, nint), Seq.empty[Int], Seq(2), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + // Simple test cases + checkAnswer( + df.select(concat_arrays($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, nint, 2))) + ) + checkAnswer( + df.selectExpr("concat_arrays(i1, i2, i3)"), + Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, nint, 2))) + ) + checkAnswer( + df.select(concat_arrays($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat_arrays(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat_arrays($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat_arrays($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat_arrays($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat_arrays($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.select(concat_arrays($"i1", $"s1")) + } + intercept[AnalysisException] { + df.select(concat_arrays(lit("a"), lit("b"))) + } + intercept[AnalysisException] { + df.selectExpr("concat_arrays(i1, i2, null)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From aa5a0898facaa0421326ebfd0fe956cffaffaafd Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 26 Mar 2018 12:55:45 +0200 Subject: [PATCH 02/18] [SPARK-23736][SQL] Code style fixes. --- .../sql/catalyst/expressions/Expression.scala | 18 +++++---- .../expressions/collectionOperations.scala | 39 ++++++++++--------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9dcf7164b61e..8abc9e0fb463 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -701,7 +701,7 @@ abstract class TernaryExpression extends Expression { trait UserDefinedExpression /** - * The trait covers logic for performing null save evaluation and code generation. 
+ * The trait covers logic for performing null safe evaluation and code generation. */ trait NullSafeEvaluation extends Expression { @@ -716,9 +716,12 @@ trait NullSafeEvaluation extends Expression */ override def eval(input: InternalRow): Any = { - val values = children.map(_.eval(input)) - if (values.contains(null)) null - else nullSafeEval(values) + val values = children.toStream.map(_.eval(input)) + if (values.contains(null)) { + null + } else { + nullSafeEval(values) + } } /** @@ -761,12 +764,11 @@ trait NullSafeEvaluation extends Expression val resultCode = f(gens.map(_.value)) if (nullable) { - val nullSafeEval = - (s""" + val nullSafeEval = children.zip(gens).foldRight(s""" ${ev.isNull} = false; // resultCode could change nullability. $resultCode - """ /: children.zip(gens)) { - case (acc, (child, gen)) => + """) { + case ((child, gen), acc) => gen.code + ctx.nullSafeExec(child.nullable, gen.isNull)(acc) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 24b732577933..a97656e62ab7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -304,8 +304,11 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS override def checkInputDataTypes(): TypeCheckResult = { val arrayCheck = checkInputDataTypesAreArrays - if(arrayCheck.isFailure) arrayCheck - else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + if(arrayCheck.isFailure) { + arrayCheck + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } } private def checkInputDataTypesAreArrays(): TypeCheckResult = @@ -352,7 +355,7 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS val variableName = ctx.freshName("numElements") val code = elements .map(el => s"$variableName += $el.numElements();") - .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s) + .foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s) (code, variableName) } @@ -372,7 +375,7 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS val unsafeArraySizeInBytes = s""" |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( - |${elementType.defaultSize} * $numElemName + | ${elementType.defaultSize} * $numElemName |); """.stripMargin val baseOffset = Platform.BYTE_ARRAY_OFFSET @@ -380,16 +383,16 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elements.map { el => s""" - |for(int z = 0; z < $el.numElements(); z++) { - | if($el.isNullAt(z)) { - | $tempArrayDataName.setNullAt($counter); - | } else { - | $tempArrayDataName.set$primitiveValueTypeName( - | $counter, - | $el.get$primitiveValueTypeName(z) - | ); - | } - | $counter++; + |for (int z = 0; z < $el.numElements(); z++) { + | if ($el.isNullAt(z)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | $el.get$primitiveValueTypeName(z) + | ); + | } + | $counter++; |} """.stripMargin }.mkString("\n") @@ -404,7 +407,7 
@@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS |int $counter = 0; |$assignments |$arrayDataName = $tempArrayDataName; - """.stripMargin + """.stripMargin } @@ -420,11 +423,11 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS val assignments = elements.map { el => s""" - |for(int z = 0; z < $el.numElements(); z++) { + |for (int z = 0; z < $el.numElements(); z++) { | $arrayName[$counter] = $el.array()[z]; | $counter++; |} - """.stripMargin + """.stripMargin }.mkString("\n") s""" @@ -433,7 +436,7 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS |int $counter = 0; |$assignments |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin + """.stripMargin } override def prettyName: String = "concat_arrays" From 90d3ab717a22812b77dbcfc1f692139cf4d1fcd9 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 26 Mar 2018 13:15:59 +0200 Subject: [PATCH 03/18] [SPARK-23736][SQL] Improving the description of the ConcatArrays expression. --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a97656e62ab7..25dd38d28e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -294,12 +294,13 @@ case class ArrayContains(left: Expression, right: Expression) * Concatenates multiple arrays into one. */ @ExpressionDescription( - usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays of the same type into one.", examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); [1,2,3,4,5,6] - """) + """, + since = "2.4.0") case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { override def checkInputDataTypes(): TypeCheckResult = { From bb46c3d3d3e18a9e05ddb6fe6efda3c25c2711a4 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 26 Mar 2018 19:38:52 +0200 Subject: [PATCH 04/18] [SPARK-23736][SQL] Merging concat and concat_arrays into one function. --- python/pyspark/sql/functions.py | 35 ++++++------------- .../sql/catalyst/analysis/Analyzer.scala | 2 ++ .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../sql/catalyst/analysis/unresolved.scala | 12 +++++++ .../expressions/collectionOperations.scala | 17 ++++++++- .../org/apache/spark/sql/functions.scala | 22 ++++-------- .../spark/sql/DataFrameFunctionsSuite.scala | 32 ++++++++--------- 7 files changed, 62 insertions(+), 61 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f08f3c8f8a57..632afb3d7f4b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1414,21 +1414,6 @@ def hash(*cols): del _name, _doc -@since(1.5) -@ignore_unicode_prefix -def concat(*cols): - """ - Concatenates multiple input columns together into a single column. - If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. 
- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) - - @since(1.5) @ignore_unicode_prefix def concat_ws(sep, *cols): @@ -1834,23 +1819,23 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) -@since(2.4) -def concat_arrays(*cols): +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): """ - Collection function: Concatenates multiple arrays into one. + Concatenates multiple input columns together into a single column. + The function works with strings, binary columns and arrays of the same time. - :param cols: list of column names (string) or list of :class:`Column` expressions that have - the same data type. + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) - >>> df.select(concat_arrays(df.a, df.b, df.c).alias("arr")).collect() + >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] """ sc = SparkContext._active_spark_context - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - args = _to_seq(sc, cols, _to_java_column) - return Column(sc._jvm.functions.concat_arrays(args)) + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) @since(1.4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7848f88bda1c..47cca93a00e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -825,6 +825,8 @@ class Analyzer( result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => ExtractValue(child, fieldExpr, resolver) + case UnresolvedConcat(children) if children.forall(_.resolved) => + ResolveConcat(children) case _ => e.mapChildren(resolve(_, q)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7c7671acb639..c7973521a310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -308,7 +308,6 @@ object FunctionRegistry { expression[BitLength]("bit_length"), expression[Length]("char_length"), expression[Length]("character_length"), - expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), expression[Elt]("elt"), @@ -408,7 +407,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[ConcatArrays]("concat_arrays"), + expression[UnresolvedConcat]("concat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a65f58fa61ff..8cc5168e7e2e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -480,3 +480,15 @@ case class UnresolvedOrdinal(ordinal: Int) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +/** + * Concatenates multiple columns of the same type into one. + * @param children Could be string, binary or array expressions + */ +case class UnresolvedConcat(children: Seq[Expression]) extends Expression + with Unevaluable { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} \ No newline at end of file diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 25dd38d28e17..bea72b39ecb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -290,6 +290,21 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Replaces [[org.apache.spark.sql.catalyst.analysis.UnresolvedConcat UnresolvedConcat]]s + * with concrete concate expressions. + */ +object ResolveConcat +{ + def apply(children: Seq[Expression]): Expression = { + if (children.nonEmpty && ArrayType.acceptsType(children(0).dataType)) { + ConcatArrays(children) + } else { + Concat(children) + } + } +} + /** * Concatenates multiple arrays into one. */ @@ -440,5 +455,5 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS """.stripMargin } - override def prettyName: String = "concat_arrays" + override def prettyName: String = "concat" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2ca9ffbc7703..5938c2d5bc9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedConcat, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -2228,16 +2228,6 @@ object functions { */ def base64(e: Column): Column = withExpr { Base64(e.expr) } - /** - * Concatenates multiple input columns together into a single column. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } - /** * Concatenates multiple input string columns together into a single string column, * using the given separator. 
@@ -3047,12 +3037,14 @@ object functions { } /** - * Merges multiple arrays into one by putting elements from the specific array after elements - * from the previous array. If any of the arrays is null, null is returned. + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary columns and arrays of the same time. + * * @group collection_funcs - * @since 2.4.0 + * @since 1.5.0 */ - def concat_arrays(columns: Column*): Column = withExpr { ConcatArrays(columns.map(_.expr)) } + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { UnresolvedConcat(exprs.map(_.expr)) } /** * Creates a new row for each element in the given array or map column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 94d1f2fa1926..a011fec74522 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,60 +413,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("concat arrays function") { - val nint : Int = null.asInstanceOf[Int] + test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null val df = Seq( (Seq(1), Seq(2, 3, 4), Seq(5, 6), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), - (Seq(1, nint), Seq.empty[Int], Seq(2), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + (Seq(1, 0), Seq.empty[Int], Seq(2), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") // Simple test cases checkAnswer( - df.select(concat_arrays($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, nint, 2))) + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 0, 2))) ) checkAnswer( - df.selectExpr("concat_arrays(i1, i2, i3)"), - Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, nint, 2))) + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 4, 5, 6)), Row(Seq(1, null, 2))) ) checkAnswer( - df.select(concat_arrays($"s1", $"s2", $"s3")), + df.select(concat($"s1", $"s2", $"s3")), Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) ) checkAnswer( - df.selectExpr("concat_arrays(s1, s2, s3)"), + df.selectExpr("concat(s1, s2, s3)"), Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) ) // Null test cases checkAnswer( - df.select(concat_arrays($"i1", $"in")), + df.select(concat($"i1", $"in")), Seq(Row(null), Row(null)) ) checkAnswer( - df.select(concat_arrays($"in", $"i1")), + df.select(concat($"in", $"i1")), Seq(Row(null), Row(null)) ) checkAnswer( - df.select(concat_arrays($"s1", $"sn")), + df.select(concat($"s1", $"sn")), Seq(Row(null), Row(null)) ) checkAnswer( - df.select(concat_arrays($"sn", $"s1")), + df.select(concat($"sn", $"s1")), Seq(Row(null), Row(null)) ) // Type error test cases intercept[AnalysisException] { - df.select(concat_arrays($"i1", $"s1")) + df.select(concat($"i1", $"s1")) } intercept[AnalysisException] { - df.select(concat_arrays(lit("a"), lit("b"))) - } - intercept[AnalysisException] { - df.selectExpr("concat_arrays(i1, i2, null)") + df.selectExpr("concat(i1, i2, null)") } } From 11205af476143c188106bb9d42e45fb564d77c14 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 26 Mar 2018 20:41:48 +0200 Subject: [PATCH 05/18] [SPARK-23736][SQL] Adding new line at the end of the 
unresolved.scala file. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 8cc5168e7e2e..6c494672032c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -491,4 +491,4 @@ case class UnresolvedConcat(children: Seq[Expression]) extends Expression override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false -} \ No newline at end of file +} From 753499d1784bdaf7c96f67dfbc3d0ff5c1e955a9 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 26 Mar 2018 23:43:09 +0200 Subject: [PATCH 06/18] [SPARK-23736][SQL] Fixing failing unit test from DDLSuite. --- .../apache/spark/sql/catalyst/analysis/unresolved.scala | 9 +++++++++ .../apache/spark/sql/execution/command/DDLSuite.scala | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6c494672032c..310c9c547c5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -485,6 +485,15 @@ case class UnresolvedOrdinal(ordinal: Int) * Concatenates multiple columns of the same type into one. * @param children Could be string, binary or array expressions */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) case class UnresolvedConcat(children: Seq[Expression]) extends Expression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4df8fbfe1c0d..029ee6532636 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1664,10 +1664,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // STRING checkAnswer( sql("DESCRIBE FUNCTION 'concat'"), - Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: + Row("Class: org.apache.spark.sql.catalyst.analysis.UnresolvedConcat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + Row("Usage: concat(col1, col2, ..., colN) - " + + "Returns the concatenation of col1, col2, ..., colN.") :: Nil ) // extended mode checkAnswer( From 2efdd771d3d4cc738060d225f621179a44259ebc Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Tue, 27 Mar 2018 14:18:43 +0200 Subject: [PATCH 07/18] [SPARK-23736][SQL] Changing method styling according to the standards. 
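
The main change is the indentation of multi-line method signatures:
parameters are now indented four spaces instead of being aligned under the
opening parenthesis, per the Spark Scala style guide. A minimal sketch of
the convention, using one of the methods reformatted below:

    private def genCodeForNumberOfElements(
        ctx: CodegenContext,
        elements: Seq[String]) : (String, String) = {
      // method bodies keep their two-space indentation
    }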
--- .../sql/catalyst/expressions/Expression.scala | 20 ++++++++-------- .../expressions/collectionOperations.scala | 23 ++++++++----------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8abc9e0fb463..f9ee59367c42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -714,8 +714,7 @@ trait NullSafeEvaluation extends Expression * If a class utilizing NullSaveEvaluation override [[nullable]], probably should also * override this. */ - override def eval(input: InternalRow): Any = - { + override def eval(input: InternalRow): Any = { val values = children.toStream.map(_.eval(input)) if (values.contains(null)) { null @@ -740,9 +739,9 @@ trait NullSafeEvaluation extends Expression * @param f accepts a sequence of variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodegenContext, - ev: ExprCode, - f: Seq[String] => String): ExprCode = { + ctx: CodegenContext, + ev: ExprCode, + f: Seq[String] => String): ExprCode = { nullSafeCodeGen(ctx, ev, values => { s"${ev.value} = ${f(values)};" }) @@ -757,9 +756,9 @@ trait NullSafeEvaluation extends Expression * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodegenContext, - ev: ExprCode, - f: Seq[String] => String): ExprCode = { + ctx: CodegenContext, + ev: ExprCode, + f: Seq[String] => String): ExprCode = { val gens = children.map(_.genCode(ctx)) val resultCode = f(gens.map(_.value)) @@ -767,9 +766,8 @@ trait NullSafeEvaluation extends Expression val nullSafeEval = children.zip(gens).foldRight(s""" ${ev.isNull} = false; // resultCode could change nullability. 
$resultCode - """) { - case ((child, gen), acc) => - gen.code + ctx.nullSafeExec(child.nullable, gen.isNull)(acc) + """) { case ((child, gen), acc) => + gen.code + ctx.nullSafeExec(child.nullable, gen.isNull)(acc) } ev.copy(code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bea72b39ecb1..8d50c058b125 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -347,7 +347,6 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS .headOption.map(_.dataType.asInstanceOf[ArrayType]) .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) - override protected def nullSafeEval(inputs: Seq[Any]): Any = { val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) new GenericArrayData(elements) @@ -365,9 +364,8 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS } private def genCodeForNumberOfElements( - ctx: CodegenContext, - elements: Seq[String] - ) : (String, String) = { + ctx: CodegenContext, + elements: Seq[String]) : (String, String) = { val variableName = ctx.freshName("numElements") val code = elements .map(el => s"$variableName += $el.numElements();") @@ -376,11 +374,10 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS } private def genCodeForConcatOfPrimitiveElements( - ctx: CodegenContext, - elementType: DataType, - elements: Seq[String], - arrayDataName: String - ): String = { + ctx: CodegenContext, + elementType: DataType, + elements: Seq[String], + arrayDataName: String): String = { val arrayName = ctx.freshName("array") val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") @@ -424,14 +421,12 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS |$assignments |$arrayDataName = $tempArrayDataName; """.stripMargin - } private def genCodeForConcatOfComplexElements( - ctx: CodegenContext, - elements: Seq[String], - arrayDataName: String - ): String = { + ctx: CodegenContext, + elements: Seq[String], + arrayDataName: String): String = { val genericArrayClass = classOf[GenericArrayData].getName val arrayName = ctx.freshName("arrayObject") val counter = ctx.freshName("counter") From fd84bee602b41ad463dbee5acd41e3d34581e7ac Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Tue, 27 Mar 2018 15:25:27 +0200 Subject: [PATCH 08/18] [SPARK-23736][SQL] Changing data type to ArrayType(StringType) for the case when no children are provided. 
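
With no children there is no array type to propagate, so dataType now falls
back to ArrayType(StringType) instead of ArrayType.defaultConcreteType
(which would yield an array of NullType). An illustrative consequence,
assuming the ConcatArrays expression as defined in this series:

    // concat() with no array arguments now resolves to array<string>
    ConcatArrays(Seq.empty).dataType == ArrayType(StringType)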
--- .../sql/catalyst/expressions/collectionOperations.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8d50c058b125..c4e7c6e9d1c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -327,8 +327,7 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS } } - private def checkInputDataTypesAreArrays(): TypeCheckResult = - { + private def checkInputDataTypesAreArrays(): TypeCheckResult = { val mismatches = children.zipWithIndex.collect { case (child, idx) if !ArrayType.acceptsType(child.dataType) => s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + @@ -342,10 +341,11 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS } } - override def dataType: ArrayType = + override def dataType: ArrayType = { children .headOption.map(_.dataType.asInstanceOf[ArrayType]) - .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType(StringType)) + } override protected def nullSafeEval(inputs: Seq[Any]): Any = { val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) From 116f91f5490fb54b4b08e785d881c310e67a2c99 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Tue, 27 Mar 2018 22:57:55 +0200 Subject: [PATCH 09/18] [SPARK-23736][SQL] Fixing a SparkR unit test by filtering out UnresolvedConcat from dataset of functions. --- R/pkg/tests/fulltests/test_sparkSQL.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 439191adb23e..364d9173ff63 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3501,7 +3501,8 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { expect_true(nrow(f) >= 200) # 250 expect_equal(colnames(f), c("name", "database", "description", "className", "isTemporary")) - expect_equal(take(orderBy(f, "className"), 1)$className, + fe <- filter(f, startsWith(f$className, "org.apache.spark.sql.catalyst.expressions")) + expect_equal(take(orderBy(fe, "className"), 1)$className, "org.apache.spark.sql.catalyst.expressions.Abs") expect_error(listFunctions("foo_db"), "Error in listFunctions : analysis error - Database 'foo_db' does not exist") From 090929f5e35e1f8aec3e83484cc8227a0436e5d7 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Sat, 7 Apr 2018 00:17:57 +0200 Subject: [PATCH 10/18] [SPARK-23736][SQL] Merging string concat and array concat into one expression. 
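
The separate ConcatArrays expression and the UnresolvedConcat placeholder
are dropped; a single Concat expression now covers string, binary and array
inputs, and TypeCoercion widens array children to a common element type.
A usage sketch (the DataFrame and its column names are illustrative):

    import org.apache.spark.sql.functions.concat
    import spark.implicits._  // assumes an active SparkSession named `spark`

    df.select(concat($"s1", $"s2"))  // string concatenation
    df.select(concat($"b1", $"b2"))  // binary concatenation
    df.select(concat($"a1", $"a2"))  // array concatenation; null if any input array is null

    spark.sql("SELECT concat(array(1, 2, 3), array(4, 5), array(6))").show()
    // [1,2,3,4,5,6]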
--- R/pkg/tests/fulltests/test_sparkSQL.R | 3 +- python/pyspark/sql/functions.py | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 2 - .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../sql/catalyst/analysis/unresolved.scala | 21 -- .../sql/catalyst/expressions/Expression.scala | 85 ------ .../expressions/collectionOperations.scala | 242 ++++++++++-------- .../expressions/stringExpressions.scala | 81 ------ .../CollectionExpressionsSuite.scala | 25 +- .../org/apache/spark/sql/functions.scala | 6 +- .../inputs/typeCoercion/native/concat.sql | 62 +++++ .../typeCoercion/native/concat.sql.out | 78 ++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 20 +- 14 files changed, 317 insertions(+), 320 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 364d9173ff63..439191adb23e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3501,8 +3501,7 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { expect_true(nrow(f) >= 200) # 250 expect_equal(colnames(f), c("name", "database", "description", "className", "isTemporary")) - fe <- filter(f, startsWith(f$className, "org.apache.spark.sql.catalyst.expressions")) - expect_equal(take(orderBy(fe, "className"), 1)$className, + expect_equal(take(orderBy(f, "className"), 1)$className, "org.apache.spark.sql.catalyst.expressions.Abs") expect_error(listFunctions("foo_db"), "Error in listFunctions : analysis error - Database 'foo_db' does not exist") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 57e60868bc8e..9fb0dd2f02ae 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1823,7 +1823,7 @@ def array_contains(col, value): def concat(*cols): """ Concatenates multiple input columns together into a single column. - The function works with strings, binary columns and arrays of the same time. + The function works with strings, binary and compatible array columns. 
>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat(df.s, df.d).alias('s')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cd1b20e7c951..e821e96522f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -827,8 +827,6 @@ class Analyzer( result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => ExtractValue(child, fieldExpr, resolver) - case UnresolvedConcat(children) if children.forall(_.resolved) => - ResolveConcat(children) case _ => e.mapChildren(resolve(_, q)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c7973521a310..f99cdda95a6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -407,7 +407,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[UnresolvedConcat]("concat"), + expression[Concat]("concat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ec7e7761dc4c..510bae88a49f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -504,6 +504,14 @@ object TypeCoercion { case None => a } + case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case None => c + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 310c9c547c5f..a65f58fa61ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -480,24 +480,3 @@ case class UnresolvedOrdinal(ordinal: Int) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } - -/** - * Concatenates multiple columns of the same type into one. 
- * @param children Could be string, binary or array expressions - */ -@ExpressionDescription( - usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - [1,2,3,4,5,6] - """) -case class UnresolvedConcat(children: Seq[Expression]) extends Expression - with Unevaluable { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a3506da1e608..38caf67d465d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -700,88 +700,3 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression - -/** - * The trait covers logic for performing null safe evaluation and code generation. - */ -trait NullSafeEvaluation extends Expression -{ - override def foldable: Boolean = children.forall(_.foldable) - - override def nullable: Boolean = children.exists(_.nullable) - - /** - * Default behavior of evaluation according to the default nullability of NullSafeEvaluation. - * If a class utilizing NullSaveEvaluation override [[nullable]], probably should also - * override this. - */ - override def eval(input: InternalRow): Any = { - val values = children.toStream.map(_.eval(input)) - if (values.contains(null)) { - null - } else { - nullSafeEval(values) - } - } - - /** - * Called by default [[eval]] implementation. If a class utilizing NullSaveEvaluation keep - * the default nullability, they can override this method to save null-check code. If we need - * full control of evaluation process, we should override [[eval]]. - */ - protected def nullSafeEval(inputs: Seq[Any]): Any = - sys.error(s"The class utilizing NullSaveEvaluation must override either eval or nullSafeEval") - - /** - * Short hand for generating of null save evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts a sequence of variable names and returns Java code to compute the output. - */ - protected def defineCodeGen( - ctx: CodegenContext, - ev: ExprCode, - f: Seq[String] => String): ExprCode = { - nullSafeCodeGen(ctx, ev, values => { - s"${ev.value} = ${f(values)};" - }) - } - - /** - * Called by expressions to generate null safe evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function that accepts a sequence of non-null evaluation result names of children - * and returns Java code to compute the output. - */ - protected def nullSafeCodeGen( - ctx: CodegenContext, - ev: ExprCode, - f: Seq[String] => String): ExprCode = { - val gens = children.map(_.genCode(ctx)) - val resultCode = f(gens.map(_.value)) - - if (nullable) { - val nullSafeEval = children.zip(gens).foldRight(s""" - ${ev.isNull} = false; // resultCode could change nullability. 
- $resultCode - """) { case ((child, gen), acc) => - gen.code + ctx.nullSafeExec(child.nullable, gen.isNull)(acc) - } - - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $nullSafeEval - """) - } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; - ${gens.map(_.code).mkString("\n")} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c4e7c6e9d1c4..d49b72ff1a4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** * Given an array or map, returns its size. Returns -1 if null. @@ -291,99 +292,123 @@ case class ArrayContains(left: Expression, right: Expression) } /** - * Replaces [[org.apache.spark.sql.catalyst.analysis.UnresolvedConcat UnresolvedConcat]]s - * with concrete concate expressions. - */ -object ResolveConcat -{ - def apply(children: Seq[Expression]): Expression = { - if (children.nonEmpty && ArrayType.acceptsType(children(0).dataType)) { - ConcatArrays(children) - } else { - Concat(children) - } - } -} - -/** - * Concatenates multiple arrays into one. + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. */ @ExpressionDescription( - usage = "_FUNC_(expr, ...) - Concatenates multiple arrays of the same type into one.", + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", examples = """ Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - [1,2,3,4,5,6] - """, - since = "2.4.0") -case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { - - override def checkInputDataTypes(): TypeCheckResult = { - val arrayCheck = checkInputDataTypesAreArrays - if(arrayCheck.isFailure) { - arrayCheck - } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") - } - } + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { - private def checkInputDataTypesAreArrays(): TypeCheckResult = { - val mismatches = children.zipWithIndex.collect { - case (child, idx) if !ArrayType.acceptsType(child.dataType) => - s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.simpleString} type." 
- } + val allowedTypes = Seq(StringType, BinaryType, ArrayType) - if (mismatches.isEmpty) { + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + + s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } } - override def dataType: ArrayType = { - children - .headOption.map(_.dataType.asInstanceOf[ArrayType]) - .getOrElse(ArrayType(StringType)) + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) + new GenericArrayData(elements) + } } - override protected def nullSafeEval(inputs: Seq[Any]): Any = { - val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) - new GenericArrayData(elements) - } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, arrays => { - val elementType = dataType.elementType - if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, ev.value) - } else { - genCodeForConcatOfComplexElements(ctx, arrays, ev.value) - } - }) + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType, evals.length) + } else { + genCodeForComplexArrayConcat(ctx, evals.length) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) + ev.copy(s""" + $initCode + $codes + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) } private def genCodeForNumberOfElements( ctx: CodegenContext, - elements: Seq[String]) : (String, String) = { + argsLength: Int) : 
(String, String) = { val variableName = ctx.freshName("numElements") - val code = elements - .map(el => s"$variableName += $el.numElements();") + val code = (0 until argsLength) + .map(idx => s"$variableName += args[$idx].numElements();") .foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s) (code, variableName) } - private def genCodeForConcatOfPrimitiveElements( + private def nullArgumentProtection(argsLength: Int) : String = + { + (0 until argsLength) + .map(idx => s"if (args[$idx] == null) return null;") + .mkString("\n") + } + + private def genCodeForPrimitiveArrayConcat( ctx: CodegenContext, elementType: DataType, - elements: Seq[String], - arrayDataName: String): String = { + argsLength: Int): String = { val arrayName = ctx.freshName("array") val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") - val tempArrayDataName = ctx.freshName("tempArrayData") + val arrayDataName = ctx.freshName("arrayData") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, argsLength) val unsafeArraySizeInBytes = s""" |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + @@ -394,61 +419,70 @@ case class ConcatArrays(children: Seq[Expression]) extends Expression with NullS val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val assignments = elements.map { el => + val assignments = (0 until argsLength).map { idx => s""" - |for (int z = 0; z < $el.numElements(); z++) { - | if ($el.isNullAt(z)) { - | $tempArrayDataName.setNullAt($counter); - | } else { - | $tempArrayDataName.set$primitiveValueTypeName( - | $counter, - | $el.get$primitiveValueTypeName(z) - | ); - | } - | $counter++; - |} + |for (int z = 0; z < args[$idx].numElements(); z++) { + | if (args[$idx].isNullAt(z)) { + | $arrayDataName.setNullAt($counter); + | } else { + | $arrayDataName.set$primitiveValueTypeName( + | $counter, + | args[$idx].get$primitiveValueTypeName(z) + | ); + | } + | $counter++; + |} """.stripMargin }.mkString("\n") - s""" - |$numElemCode - |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[$arraySizeName]; - |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); - |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); - |int $counter = 0; - |$assignments - |$arrayDataName = $tempArrayDataName; - """.stripMargin + s"""new Object() { + | public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) { + | ${nullArgumentProtection(argsLength)} + | $numElemCode + | $unsafeArraySizeInBytes + | byte[] $arrayName = new byte[$arraySizeName]; + | UnsafeArrayData $arrayDataName = new UnsafeArrayData(); + | Platform.putLong($arrayName, $baseOffset, $numElemName); + | $arrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); + | int $counter = 0; + | $assignments + | return $arrayDataName; + | } + |}""".stripMargin } - private def genCodeForConcatOfComplexElements( - ctx: CodegenContext, - elements: Seq[String], - arrayDataName: String): String = { + private def genCodeForComplexArrayConcat( + ctx: CodegenContext, + argsLength: Int): String = { val genericArrayClass = classOf[GenericArrayData].getName val arrayName = ctx.freshName("arrayObject") val counter = ctx.freshName("counter") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + val className = ctx.freshName("ComplexArrayConcat") + val 
(numElemCode, numElemName) = genCodeForNumberOfElements(ctx, argsLength) - val assignments = elements.map { el => + val assignments = (0 until argsLength).map { idx => s""" - |for (int z = 0; z < $el.numElements(); z++) { - | $arrayName[$counter] = $el.array()[z]; - | $counter++; - |} + |for (int z = 0; z < args[$idx].numElements(); z++) { + | $arrayName[$counter] = args[$idx].array()[z]; + | $counter++; + |} """.stripMargin }.mkString("\n") - s""" - |$numElemCode - |Object[] $arrayName = new Object[$numElemName]; - |int $counter = 0; - |$assignments - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin + s"""new Object() { + | public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) { + | ${nullArgumentProtection(argsLength)} + | $numElemCode + | Object[] $arrayName = new Object[$numElemName]; + | int $counter = 0; + | $assignments + | return new $genericArrayClass($arrayName); + | } + |}""".stripMargin } - override def prettyName: String = "concat" + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 22fbb8998ed8..503e64dcee65 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * An expression that concatenates multiple inputs into a single output. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * If any input is null, concat returns null. 
- */ -@ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private lazy val isBinaryMode: Boolean = dataType == BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } - } - - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - if (isBinaryMode) { - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - } else { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } - - val (concatenator, initCode) = if (isBinaryMode) { - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - } else { - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - override def toString: String = s"concat(${children.mkString(", ")})" - - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" -} - - /** * An expression that concatenates multiple input strings or array of strings into a single string, * using a given separator (the first child). 
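
With the string-specific Concat removed above, a single expression now serves string, binary and array inputs. The array branch of the interpreted path reduces to a null-propagating flatten; the following is a minimal standalone sketch of that semantics in plain Scala, with illustrative names only, not the actual expression code:

      // Any null input array makes the whole result null;
      // null elements *inside* an input array are preserved.
      def concatArrays[T](inputs: Seq[Seq[T]]): Seq[T] =
        if (inputs.contains(null)) null else inputs.flatten

      concatArrays(Seq(Seq(1, 2), Seq(3)))        // Seq(1, 2, 3)
      concatArrays(Seq(Seq(1, 2), null, Seq(3)))  // null

The string and binary branches keep their previous behaviour by delegating to UTF8String.concat and ByteArray.concat, both of which already return null when any argument is null.
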
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 78be7a20a184..42e64579dfbd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,7 +106,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } - test("Concat Arrays") { + test("Concat") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) val a2 = Literal.create(Seq(4, null, 6), ArrayType(IntegerType)) @@ -115,20 +115,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a5 = Literal.create(Seq.empty[String], ArrayType(StringType)) val an = Literal.create(null, ArrayType(StringType)) - checkEvaluation(ConcatArrays(Seq(a0)), Seq(1, 2, 3)) - checkEvaluation(ConcatArrays(Seq(a0, a2)), Seq(1, 2, 3, 4, null, 6)) - checkEvaluation(ConcatArrays(Seq(a0, a1, a2)), Seq(1, 2, 3, 4, null, 6)) + checkEvaluation(Concat(Seq(a0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(a0, a2)), Seq(1, 2, 3, 4, null, 6)) + checkEvaluation(Concat(Seq(a0, a1, a2)), Seq(1, 2, 3, 4, null, 6)) - checkEvaluation(ConcatArrays(Seq(a3, a4)), Seq("a", "b", "c", "e", null)) - checkEvaluation(ConcatArrays(Seq(a3, a4, a5)), Seq("a", "b", "c", "e", null)) - checkEvaluation(ConcatArrays(Seq(a5)), Seq()) + checkEvaluation(Concat(Seq(a3, a4)), Seq("a", "b", "c", "e", null)) + checkEvaluation(Concat(Seq(a3, a4, a5)), Seq("a", "b", "c", "e", null)) + checkEvaluation(Concat(Seq(a5)), Seq()) - checkEvaluation(ConcatArrays(Seq()), Seq()) - checkEvaluation(ConcatArrays(Seq(a0, a0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(a0, a0)), Seq(1, 2, 3, 1, 2, 3)) - checkEvaluation(ConcatArrays(Seq(an)), null) - checkEvaluation(ConcatArrays(Seq(a3, an)), null) - checkEvaluation(ConcatArrays(Seq(an, a3)), null) - checkEvaluation(ConcatArrays(Seq(a3, an, a4)), null) + checkEvaluation(Concat(Seq(an)), null) + checkEvaluation(Concat(Seq(a3, an)), null) + checkEvaluation(Concat(Seq(an, a3)), null) + checkEvaluation(Concat(Seq(a3, an, a4)), null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5938c2d5bc9e..456c0db0da50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedConcat, UnresolvedFunction} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -3038,13 +3038,13 @@ object functions { /** * Concatenates multiple input columns together into a single column. - * The function works with strings, binary columns and arrays of the same time. 
+ * The function works with strings, binary and compatible array columns. * * @group collection_funcs * @since 1.5.0 */ @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { UnresolvedConcat(exprs.map(_.expr)) } + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } /** * Creates a new row for each element in the given array or map column. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 0beebec5702f..db00a18f2e7e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -91,3 +91,65 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +); + +-- Concatenate arrays of the same type +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays; + +-- Concatenate arrays of different types +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 
09729fdc2ec3..62befc5ca0f1 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -237,3 +237,81 @@ struct 78910 891011 9101112 + + +-- !query 11 +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays +-- !query 12 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,data_array:array,timestamp_array:array,string_array:array,array_array:array>,struct_array:array>,map_array:array>> +-- !query 12 output +[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}] + + +-- !query 13 +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays 
+-- !query 13 schema +struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +-- !query 13 output +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a011fec74522..c0f605dd474d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -417,18 +417,27 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null val df = Seq( - (Seq(1), Seq(2, 3, 4), Seq(5, 6), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), - (Seq(1, 0), Seq.empty[Int], Seq(2), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) checkAnswer( df.select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 0, 2))) + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) ) checkAnswer( df.selectExpr("concat(array(1, null), i2, i3)"), - Seq(Row(Seq(1, null, 2, 3, 4, 5, 6)), Row(Seq(1, null, 2))) + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) ) checkAnswer( df.select(concat($"s1", $"s2", $"s3")), @@ -458,9 +467,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) // Type error test cases - intercept[AnalysisException] { - df.select(concat($"i1", $"s1")) - } intercept[AnalysisException] { df.selectExpr("concat(i1, i2, null)") } From 8abd1a8b92eee5b83c13a1969dcbfca7e6cb6a06 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Sat, 7 Apr 2018 14:50:44 +0200 Subject: [PATCH 11/18] [SPARK-23736][SQL] Adding more test cases --- .../expressions/collectionOperations.scala | 10 ++-- .../CollectionExpressionsSuite.scala | 60 ++++++++++++------- .../spark/sql/DataFrameFunctionsSuite.scala | 5 ++ .../sql/execution/command/DDLSuite.scala | 2 +- 4 files changed, 50 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d49b72ff1a4a..1da0e472a789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -394,9 +394,11 @@ case class Concat(children: Seq[Expression]) extends Expression { private def nullArgumentProtection(argsLength: Int) : String = { - (0 until argsLength) - .map(idx => s"if (args[$idx] == null) return null;") - .mkString("\n") + if (nullable) { + (0 until argsLength).map(idx => s"if (args[$idx] == null) return null;").mkString("\n") + } else { + "" + } } private def genCodeForPrimitiveArrayConcat( @@ -457,7 +459,7 @@ case class Concat(children: 
Seq[Expression]) extends Expression { val genericArrayClass = classOf[GenericArrayData].getName val arrayName = ctx.freshName("arrayObject") val counter = ctx.freshName("counter") - val className = ctx.freshName("ComplexArrayConcat") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, argsLength) val assignments = (0 until argsLength).map { idx => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 42e64579dfbd..b5126531c633 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -107,27 +107,43 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Concat") { - val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) - val a2 = Literal.create(Seq(4, null, 6), ArrayType(IntegerType)) - val a3 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) - val a4 = Literal.create(Seq("e", null), ArrayType(StringType)) - val a5 = Literal.create(Seq.empty[String], ArrayType(StringType)) - val an = Literal.create(null, ArrayType(StringType)) - - checkEvaluation(Concat(Seq(a0)), Seq(1, 2, 3)) - checkEvaluation(Concat(Seq(a0, a2)), Seq(1, 2, 3, 4, null, 6)) - checkEvaluation(Concat(Seq(a0, a1, a2)), Seq(1, 2, 3, 4, null, 6)) - - checkEvaluation(Concat(Seq(a3, a4)), Seq("a", "b", "c", "e", null)) - checkEvaluation(Concat(Seq(a3, a4, a5)), Seq("a", "b", "c", "e", null)) - checkEvaluation(Concat(Seq(a5)), Seq()) - - checkEvaluation(Concat(Seq(a0, a0)), Seq(1, 2, 3, 1, 2, 3)) - - checkEvaluation(Concat(Seq(an)), null) - checkEvaluation(Concat(Seq(a3, an)), null) - checkEvaluation(Concat(Seq(an, a3)), null) - checkEvaluation(Concat(Seq(a3, an, a4)), null) + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val ai4 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5)) + checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5)) + checkEvaluation(Concat(Seq(ai4)), null) + checkEvaluation(Concat(Seq(ai0, ai4)), null) + checkEvaluation(Concat(Seq(ai4, ai0)), null) + + // Complex-type elements + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) + val as4 = Literal.create(null, ArrayType(StringType)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + + checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, 
as1)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e")) + checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e")) + checkEvaluation(Concat(Seq(as4)), null) + checkEvaluation(Concat(Seq(as0, as4)), null) + checkEvaluation(Concat(Seq(as4, as0)), null) + + checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c0f605dd474d..fd8025c47a73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -417,6 +417,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null val df = Seq( + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") @@ -470,6 +471,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df.selectExpr("concat(i1, i2, null)") } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 029ee6532636..cd786d205de4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1664,7 +1664,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // STRING checkAnswer( sql("DESCRIBE FUNCTION 'concat'"), - Row("Class: org.apache.spark.sql.catalyst.analysis.UnresolvedConcat") :: + Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: Row("Usage: concat(col1, col2, ..., colN) - " + "Returns the concatenation of col1, col2, ..., colN.") :: Nil From 367ee2241901225e7451d7280611cecf23be82f1 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Sun, 8 Apr 2018 00:22:29 +0200 Subject: [PATCH 12/18] [SPARK-23736][SQL] Optimizing null elements protection. 
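
Null checks in the generated concat method are now derived from each child's static nullability instead of the expression-level nullable flag, so arguments that can never be null no longer pay for a runtime guard. A sketch of the selection logic with illustrative names (the diff below expresses the same idea via filter and map over children.zipWithIndex):

    // Emit a null guard only for children that can actually evaluate to null.
    val nullGuards: String = children.zipWithIndex.collect {
      case (child, idx) if child.nullable => s"if (args[$idx] == null) return null;"
    }.mkString("\n")

If no child is nullable, the guard section collapses to an empty string and the generated concat method stays branch-free.
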
--- .../expressions/collectionOperations.scala | 44 +++++++------------ 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1da0e472a789..e99458b36e96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -364,9 +364,9 @@ case class Concat(children: Seq[Expression]) extends Expression { ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, _) => val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrayConcat(ctx, elementType, evals.length) + genCodeForPrimitiveArrayConcat(ctx, elementType) } else { - genCodeForComplexArrayConcat(ctx, evals.length) + genCodeForComplexArrayConcat(ctx) } (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") } @@ -382,35 +382,28 @@ case class Concat(children: Seq[Expression]) extends Expression { """) } - private def genCodeForNumberOfElements( - ctx: CodegenContext, - argsLength: Int) : (String, String) = { + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { val variableName = ctx.freshName("numElements") - val code = (0 until argsLength) + val code = (0 until children.length) .map(idx => s"$variableName += args[$idx].numElements();") .foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s) (code, variableName) } - private def nullArgumentProtection(argsLength: Int) : String = - { - if (nullable) { - (0 until argsLength).map(idx => s"if (args[$idx] == null) return null;").mkString("\n") - } else { - "" - } + private def nullArgumentProtection() : String = { + children.zipWithIndex + .filter(_._1.nullable) + .map(ci => s"if (args[${ci._2}] == null) return null;") + .mkString("\n") } - private def genCodeForPrimitiveArrayConcat( - ctx: CodegenContext, - elementType: DataType, - argsLength: Int): String = { + private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, elementType: DataType): String = { val arrayName = ctx.freshName("array") val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val arrayDataName = ctx.freshName("arrayData") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, argsLength) + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) val unsafeArraySizeInBytes = s""" |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + @@ -421,7 +414,7 @@ case class Concat(children: Seq[Expression]) extends Expression { val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val assignments = (0 until argsLength).map { idx => + val assignments = (0 until children.length).map { idx => s""" |for (int z = 0; z < args[$idx].numElements(); z++) { | if (args[$idx].isNullAt(z)) { @@ -439,7 +432,7 @@ case class Concat(children: Seq[Expression]) extends Expression { s"""new Object() { | public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) { - | ${nullArgumentProtection(argsLength)} + | ${nullArgumentProtection()} | $numElemCode | $unsafeArraySizeInBytes | byte[] $arrayName = new byte[$arraySizeName]; @@ -453,16 +446,14 @@ case class Concat(children: Seq[Expression]) extends Expression 
{ |}""".stripMargin } - private def genCodeForComplexArrayConcat( - ctx: CodegenContext, - argsLength: Int): String = { + private def genCodeForComplexArrayConcat(ctx: CodegenContext): String = { val genericArrayClass = classOf[GenericArrayData].getName val arrayName = ctx.freshName("arrayObject") val counter = ctx.freshName("counter") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, argsLength) + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val assignments = (0 until argsLength).map { idx => + val assignments = (0 until children.length).map { idx => s""" |for (int z = 0; z < args[$idx].numElements(); z++) { | $arrayName[$counter] = args[$idx].array()[z]; @@ -473,7 +464,7 @@ case class Concat(children: Seq[Expression]) extends Expression { s"""new Object() { | public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) { - | ${nullArgumentProtection(argsLength)} + | ${nullArgumentProtection()} | $numElemCode | Object[] $arrayName = new Object[$numElemName]; | int $counter = 0; @@ -483,7 +474,6 @@ case class Concat(children: Seq[Expression]) extends Expression { |}""".stripMargin } - override def toString: String = s"concat(${children.mkString(", ")})" override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" From 6bb33e6eeb2028d64b4258413523e02b08263059 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Thu, 12 Apr 2018 19:33:55 +0200 Subject: [PATCH 13/18] [SPARK-23736][SQL] Protection against the length limit of Java functions --- .../expressions/collectionOperations.scala | 104 ++++++++++++------ 1 file changed, 72 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e99458b36e96..ff257663f3ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -324,6 +324,8 @@ case class Concat(children: Seq[Expression]) extends Expression { override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + lazy val javaType: String = CodeGenerator.javaType(dataType) + override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -373,35 +375,62 @@ case class Concat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) + extraArguments = (s"${javaType}[]", args) :: Nil) ev.copy(s""" $initCode $codes - ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + ${javaType} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { - val variableName = ctx.freshName("numElements") - val code = (0 until children.length) - .map(idx => s"$variableName += args[$idx].numElements();") - .foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s) - (code, variableName) + val tempVariableName = ctx.freshName("tempNumElements") + val numElementsConstant = ctx.freshName("numElements") + val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + + val 
assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", tempVariableName))) + + (s""" + |int[] $tempVariableName = new int[]{0}; + |$assignmentSection + |final int $numElementsConstant = $tempVariableName[0]; + """.stripMargin, + numElementsConstant) } - private def nullArgumentProtection() : String = { - children.zipWithIndex + private def nullArgumentProtection(ctx: CodegenContext) : String = { + val isNullVariable = ctx.freshName("isArrayNull") + val assignments = children + .zipWithIndex .filter(_._1.nullable) - .map(ci => s"if (args[${ci._2}] == null) return null;") - .mkString("\n") + .map(ci => s"$isNullVariable[0] |= args[${ci._2}] == null;") + + if (assignments.length > 0) { + val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "isNullArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("boolean[]", isNullVariable))) + + s""" + |boolean[] $isNullVariable = new boolean[]{false}; + |$assignmentSection; + |if ($isNullVariable[0]) return null; + """.stripMargin + } else { + "" + } } private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, elementType: DataType): String = { val arrayName = ctx.freshName("array") val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") - val arrayDataName = ctx.freshName("arrayData") + val arrayData = ctx.freshName("arrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) @@ -418,37 +447,44 @@ case class Concat(children: Seq[Expression]) extends Expression { s""" |for (int z = 0; z < args[$idx].numElements(); z++) { | if (args[$idx].isNullAt(z)) { - | $arrayDataName.setNullAt($counter); + | $arrayData.setNullAt($counter[0]); | } else { - | $arrayDataName.set$primitiveValueTypeName( - | $counter, + | $arrayData.set$primitiveValueTypeName( + | $counter[0], | args[$idx].get$primitiveValueTypeName(z) | ); | } - | $counter++; + | $counter[0]++; |} """.stripMargin - }.mkString("\n") + } + val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "primitiveArrayConcat", + arguments = Seq( + (s"${javaType}[]", "args"), + ("UnsafeArrayData", arrayData), + ("int[]", counter))) s"""new Object() { | public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) { - | ${nullArgumentProtection()} + | ${nullArgumentProtection(ctx)} | $numElemCode | $unsafeArraySizeInBytes | byte[] $arrayName = new byte[$arraySizeName]; - | UnsafeArrayData $arrayDataName = new UnsafeArrayData(); + | UnsafeArrayData $arrayData = new UnsafeArrayData(); | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); - | int $counter = 0; - | $assignments - | return $arrayDataName; + | $arrayData.pointTo($arrayName, $baseOffset, $arraySizeName); + | int[] $counter = new int[]{0}; + | $assignmentSection + | return $arrayData; | } |}""".stripMargin } private def genCodeForComplexArrayConcat(ctx: CodegenContext): String = { val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") + val arrayData = ctx.freshName("arrayObjects") val counter = ctx.freshName("counter") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) @@ -456,20 +492,24 @@ case class Concat(children: Seq[Expression]) extends Expression { val assignments = (0 until children.length).map { idx => s""" |for (int z = 0; z < args[$idx].numElements(); z++) { - | 
$arrayName[$counter] = args[$idx].array()[z]; - | $counter++; + | $arrayData[$counter[0]] = args[$idx].array()[z]; + | $counter[0]++; |} """.stripMargin - }.mkString("\n") + } + val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("Object[]", arrayData), ("int[]", counter))) s"""new Object() { | public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) { - | ${nullArgumentProtection()} + | ${nullArgumentProtection(ctx)} | $numElemCode - | Object[] $arrayName = new Object[$numElemName]; - | int $counter = 0; - | $assignments - | return new $genericArrayClass($arrayName); + | Object[] $arrayData = new Object[$numElemName]; + | int[] $counter = new int[]{0}; + | $assignmentSection + | return new $genericArrayClass($arrayData); | } |}""".stripMargin } From 944e0c9f5312eab9581c961704e875d5bd629878 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Thu, 12 Apr 2018 23:02:25 +0200 Subject: [PATCH 14/18] [SPARK-23736][SQL] Adding test for the limit of Java function size. --- .../sql/catalyst/expressions/collectionOperations.scala | 2 +- .../catalyst/expressions/CollectionExpressionsSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5406dbfe2a3d..8546f04e3e1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b5126531c633..3b4f9529978b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -124,6 +124,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(ai0, ai4)), null) checkEvaluation(Concat(Seq(ai4, ai0)), null) + // Test of handling the limit of Java function size + val iIndexes = 1 to 1200 + checkEvaluation(Concat(iIndexes.map(_ => ai2)), iIndexes.flatMap(_ => Seq(4, null, 5))) + // Complex-type elements val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) @@ -145,5 +149,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(as4, as0)), null) checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + + // Test of handling the limit of Java function size + val 
sIndexes = 1 to 1200 + checkEvaluation(Concat(sIndexes.map(_ => as2)), sIndexes.flatMap(_ => Seq("d", null, "e"))) } } From 7f5124ba8752387b3e1d6c0922b551a2cba98356 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Fri, 13 Apr 2018 10:26:21 +0200 Subject: [PATCH 15/18] [SPARK-23736][SQL] Adding more tests --- .../catalyst/expressions/collectionOperations.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 10 ++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8546f04e3e1d..e442e675b366 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -368,7 +368,7 @@ case class Concat(children: Seq[Expression]) extends Expression { val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { genCodeForPrimitiveArrayConcat(ctx, elementType) } else { - genCodeForComplexArrayConcat(ctx) + genCodeForComplexArrayConcat(ctx, elementType) } (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") } @@ -451,7 +451,7 @@ case class Concat(children: Seq[Expression]) extends Expression { | } else { | $arrayData.set$primitiveValueTypeName( | $counter[0], - | args[$idx].get$primitiveValueTypeName(z) + | ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")} | ); | } | $counter[0]++; @@ -482,7 +482,7 @@ case class Concat(children: Seq[Expression]) extends Expression { |}""".stripMargin } - private def genCodeForComplexArrayConcat(ctx: CodegenContext): String = { + private def genCodeForComplexArrayConcat(ctx: CodegenContext, elementType: DataType): String = { val genericArrayClass = classOf[GenericArrayData].getName val arrayData = ctx.freshName("arrayObjects") val counter = ctx.freshName("counter") @@ -492,7 +492,7 @@ case class Concat(children: Seq[Expression]) extends Expression { val assignments = (0 until children.length).map { idx => s""" |for (int z = 0; z < args[$idx].numElements(); z++) { - | $arrayData[$counter[0]] = args[$idx].array()[z]; + | $arrayData[$counter[0]] = ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")}; | $counter[0]++; |} """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fd8025c47a73..20959b8a5aa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -422,6 +422,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + // Simple test cases checkAnswer( df.selectExpr("array(1, 2, 3L)"), @@ -436,6 +438,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(concat($"i1", $"i2", $"i3")), Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) checkAnswer( df.selectExpr("concat(array(1, null), i2, 
i3)"), Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) @@ -448,6 +454,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("concat(s1, s2, s3)"), Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) // Null test cases checkAnswer( From 0201e4b64da08156ff4d96d55d51697b79807028 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 16 Apr 2018 18:01:20 +0200 Subject: [PATCH 16/18] [SPARK-23736][SQL] Checks of max array size + Rewriting codegen using for loops. --- .../spark/unsafe/array/ByteArrayMethods.java | 6 +- .../catalyst/expressions/UnsafeArrayData.java | 10 + .../expressions/collectionOperations.scala | 172 ++++++++---------- .../CollectionExpressionsSuite.scala | 10 +- 4 files changed, 95 insertions(+), 103 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 4bc9955090fd..ef0f78d95d1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8546c2833553..d5d934bc91ca 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -56,9 +56,19 @@ public final class UnsafeArrayData extends ArrayData { public static int calculateHeaderPortionInBytes(int numFields) { + return (int)calculateHeaderPortionInBytes((long)numFields); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e442e675b366..7855b007cfdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -306,6 +306,8 @@ case class ArrayContains(left: Expression, right: Expression) """) case class Concat(children: Seq[Expression]) extends Expression { + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + 
val allowedTypes = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { @@ -341,8 +343,20 @@ case class Concat(children: Seq[Expression]) extends Expression { if (inputs.contains(null)) { null } else { - val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) - new GenericArrayData(elements) + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) } } @@ -366,9 +380,9 @@ case class Concat(children: Seq[Expression]) extends Expression { ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, _) => val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrayConcat(ctx, elementType) + genCodeForPrimitiveArrays(ctx, elementType) } else { - genCodeForComplexArrayConcat(ctx, elementType) + genCodeForNonPrimitiveArrays(ctx, elementType) } (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") } @@ -385,48 +399,34 @@ case class Concat(children: Seq[Expression]) extends Expression { } private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { - val tempVariableName = ctx.freshName("tempNumElements") - val numElementsConstant = ctx.freshName("numElements") - val assignments = (0 until children.length) - .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") - - val assignmentSection = ctx.splitExpressions( - expressions = assignments, - funcName = "complexArrayConcat", - arguments = Seq((s"${javaType}[]", "args"), ("int[]", tempVariableName))) - - (s""" - |int[] $tempVariableName = new int[]{0}; - |$assignmentSection - |final int $numElementsConstant = $tempVariableName[0]; - """.stripMargin, - numElementsConstant) - } - - private def nullArgumentProtection(ctx: CodegenContext) : String = { - val isNullVariable = ctx.freshName("isArrayNull") - val assignments = children - .zipWithIndex - .filter(_._1.nullable) - .map(ci => s"$isNullVariable[0] |= args[${ci._2}] == null;") + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with $numElements" + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin - if (assignments.length > 0) { - val assignmentSection = ctx.splitExpressions( - expressions = assignments, - funcName = "isNullArrayConcat", - arguments = Seq((s"${javaType}[]", "args"), ("boolean[]", isNullVariable))) + (code, numElements) + } + private def nullArgumentProtection() : String = { + if (nullable) { s""" - |boolean[] $isNullVariable = new boolean[]{false}; - |$assignmentSection; - |if ($isNullVariable[0]) return null; + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} """.stripMargin } else { "" } } - private 
-  private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, elementType: DataType): String = {
+  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
     val arrayName = ctx.freshName("array")
     val arraySizeName = ctx.freshName("size")
     val counter = ctx.freshName("counter")
@@ -435,83 +435,69 @@ case class Concat(children: Seq[Expression]) extends Expression {
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
     val unsafeArraySizeInBytes = s"""
-      |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
-      |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
-      |  ${elementType.defaultSize} * $numElemName
-      |);
+      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+      |  $numElemName,
+      |  ${elementType.defaultSize});
+      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with $arraySizeName bytes" +
+      |    " of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes for UnsafeArrayData.");
+      |}
       """.stripMargin
     val baseOffset = Platform.BYTE_ARRAY_OFFSET
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-    val assignments = (0 until children.length).map { idx =>
-      s"""
-      |for (int z = 0; z < args[$idx].numElements(); z++) {
-      |  if (args[$idx].isNullAt(z)) {
-      |    $arrayData.setNullAt($counter[0]);
-      |  } else {
-      |    $arrayData.set$primitiveValueTypeName(
-      |      $counter[0],
-      |      ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")}
-      |    );
-      |  }
-      |  $counter[0]++;
-      |}
-      """.stripMargin
-    }
-    val assignmentSection = ctx.splitExpressions(
-      expressions = assignments,
-      funcName = "primitiveArrayConcat",
-      arguments = Seq(
-        (s"${javaType}[]", "args"),
-        ("UnsafeArrayData", arrayData),
-        ("int[]", counter)))
-
-    s"""new Object() {
+
+    s"""
+    |new Object() {
     |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
-    |    ${nullArgumentProtection(ctx)}
+    |    ${nullArgumentProtection()}
     |    $numElemCode
-    |    $unsafeArraySizeInBytes
-    |    byte[] $arrayName = new byte[$arraySizeName];
+    |    $unsafeArraySizeInBytes
+    |    byte[] $arrayName = new byte[(int)$arraySizeName];
     |    UnsafeArrayData $arrayData = new UnsafeArrayData();
     |    Platform.putLong($arrayName, $baseOffset, $numElemName);
-    |    $arrayData.pointTo($arrayName, $baseOffset, $arraySizeName);
-    |    int[] $counter = new int[]{0};
-    |    $assignmentSection
+    |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+    |    int $counter = 0;
+    |    for (int y = 0; y < ${children.length}; y++) {
+    |      for (int z = 0; z < args[y].numElements(); z++) {
+    |        if (args[y].isNullAt(z)) {
+    |          $arrayData.setNullAt($counter);
+    |        } else {
+    |          $arrayData.set$primitiveValueTypeName(
+    |            $counter,
+    |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+    |          );
+    |        }
+    |        $counter++;
+    |      }
+    |    }
     |    return $arrayData;
     |  }
-    |}""".stripMargin
+    |}""".stripMargin.stripPrefix("\n")
   }
 
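For context on the size guard the generated code now performs before allocating the byte[]: UnsafeArrayData stores an 8-byte element count, a null bitmap padded to 8-byte words, and a word-aligned data region, so the byte size can overflow an int even when the element count does not. A simplified re-derivation of calculateSizeOfUnderlyingByteArray, under the assumption that this layout description is accurate:

```scala
// Simplified sketch of UnsafeArrayData.calculateSizeOfUnderlyingByteArray
// (assumed layout: 8-byte count, null bitmap in 8-byte words, aligned data).
def sizeOfUnderlyingByteArray(numElements: Long, elementSize: Int): Long = {
  val headerInBytes = 8L + ((numElements + 63) / 64) * 8      // count + null bits
  val dataInBytes = ((numElements * elementSize + 7) / 8) * 8 // rounded to a word
  headerInBytes + dataInBytes
}
```

Because the result is a long, comparing it against the limit before the (int) cast is what keeps the allocation safe.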
-  private def genCodeForComplexArrayConcat(ctx: CodegenContext, elementType: DataType): String = {
+  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
     val genericArrayClass = classOf[GenericArrayData].getName
     val arrayData = ctx.freshName("arrayObjects")
     val counter = ctx.freshName("counter")
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
-    val assignments = (0 until children.length).map { idx =>
-      s"""
-      |for (int z = 0; z < args[$idx].numElements(); z++) {
-      |  $arrayData[$counter[0]] = ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")};
-      |  $counter[0]++;
-      |}
-      """.stripMargin
-    }
-    val assignmentSection = ctx.splitExpressions(
-      expressions = assignments,
-      funcName = "complexArrayConcat",
-      arguments = Seq((s"${javaType}[]", "args"), ("Object[]", arrayData), ("int[]", counter)))
-
-    s"""new Object() {
+    s"""
+    |new Object() {
     |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
-    |    ${nullArgumentProtection(ctx)}
+    |    ${nullArgumentProtection()}
     |    $numElemCode
-    |    Object[] $arrayData = new Object[$numElemName];
-    |    int[] $counter = new int[]{0};
-    |    $assignmentSection
+    |    Object[] $arrayData = new Object[(int)$numElemName];
+    |    int $counter = 0;
+    |    for (int y = 0; y < ${children.length}; y++) {
+    |      for (int z = 0; z < args[y].numElements(); z++) {
+    |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
+    |        $counter++;
+    |      }
+    |    }
     |    return new $genericArrayClass($arrayData);
     |  }
-    |}""".stripMargin
+    |}""".stripMargin.stripPrefix("\n")
   }
 
   override def toString: String = s"concat(${children.mkString(", ")})"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 3b4f9529978b..23d37e4b59d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -124,11 +124,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Concat(Seq(ai0, ai4)), null)
     checkEvaluation(Concat(Seq(ai4, ai0)), null)
 
-    // Test of handling the limit of Java function size
-    val iIndexes = 1 to 1200
-    checkEvaluation(Concat(iIndexes.map(_ => ai2)), iIndexes.flatMap(_ => Seq(4, null, 5)))
-
-    // Complex-type elements
+    // Non-primitive-type elements
     val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
     val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
     val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
@@ -149,9 +145,5 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Concat(Seq(as4, as0)), null)
 
     checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
-
-    // Test of handling the limit of Java function size
-    val sIndexes = 1 to 1200
-    checkEvaluation(Concat(sIndexes.map(_ => as2)), sIndexes.flatMap(_ => Seq("d", null, "e")))
   }
 }
 
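The dropped test cases are obsolete by design: with the concatenation loops emitted directly in the generated code (rather than split per-child via ctx.splitExpressions), the 1200-argument stress tests no longer exercise the Java method-size limit. A minimal spot-check in the suite's own style, with illustrative fixture names not taken from the suite:

```scala
val x = Literal.create(Seq(1, 2), ArrayType(IntegerType))
val y = Literal.create(Seq(3, null), ArrayType(IntegerType, containsNull = true))
checkEvaluation(Concat(Seq(x, y)), Seq(1, 2, 3, null))
```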
From f2a67e82880896bf7c09d3067f7d1699c43d2505 Mon Sep 17 00:00:00 2001
From: mn-mikke
Date: Tue, 17 Apr 2018 15:19:10 +0200
Subject: [PATCH 17/18] [SPARK-23736][SQL] Fixing exception messages

---
 .../sql/catalyst/expressions/collectionOperations.scala | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ea951f33ce22..6db33b620b64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -472,7 +472,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
       |  $numElements += args[z].numElements();
       |}
       |if ($numElements > $MAX_ARRAY_LENGTH) {
-      |  throw new RuntimeException("Unsuccessful try to concat arrays with $numElements" +
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
       |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
       |}
       """.stripMargin
@@ -505,8 +505,9 @@ case class Concat(children: Seq[Expression]) extends Expression {
       |  $numElemName,
       |  ${elementType.defaultSize});
       |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
-      |  throw new RuntimeException("Unsuccessful try to concat arrays with $arraySizeName bytes" +
-      |    " of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes for UnsafeArrayData.");
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
+      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
+      |    " for UnsafeArrayData.");
       |}
       """.stripMargin
     val baseOffset = Platform.BYTE_ARRAY_OFFSET
@@ -517,7 +518,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
     |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
     |    ${nullArgumentProtection()}
     |    $numElemCode
-    |    $unsafeArraySizeInBytes
+    |    $unsafeArraySizeInBytes
     |    byte[] $arrayName = new byte[(int)$arraySizeName];
     |    UnsafeArrayData $arrayData = new UnsafeArrayData();
     |    Platform.putLong($arrayName, $baseOffset, $numElemName);

From 8a125d9f536269f2aa2f659e9acc99c31e7854e1 Mon Sep 17 00:00:00 2001
From: mn-mikke
Date: Wed, 18 Apr 2018 11:17:42 +0200
Subject: [PATCH 18/18] [SPARK-23736][SQL] Small refactoring

---
 .../sql/catalyst/expressions/collectionOperations.scala | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6db33b620b64..696fdc56f027 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -395,6 +395,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
   lazy val javaType: String = CodeGenerator.javaType(dataType)
 
   override def nullable: Boolean = children.exists(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = dataType match {
@@ -455,11 +456,11 @@ case class Concat(children: Seq[Expression]) extends Expression {
     val codes = ctx.splitExpressionsWithCurrentInputs(
       expressions = inputs,
       funcName = "valueConcat",
-      extraArguments = (s"${javaType}[]", args) :: Nil)
+      extraArguments = (s"$javaType[]", args) :: Nil)
     ev.copy(s"""
       $initCode
      $codes
-      ${javaType} ${ev.value} = $concatenator.concat($args);
+      $javaType ${ev.value} = $concatenator.concat($args);
      boolean ${ev.isNull} = ${ev.value} == null;
    """)
  }
@@ -515,7 +516,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
 
     s"""
     |new Object() {
-    |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
+    |  public ArrayData concat($javaType[] args) {
     |    ${nullArgumentProtection()}
     |    $numElemCode
     |    $unsafeArraySizeInBytes
@@ -551,7 +552,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
 
     s"""
     |new Object() {
-    |  public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
+    |  public ArrayData concat($javaType[] args) {
     |    ${nullArgumentProtection()}
     |    $numElemCode
     |    Object[] $arrayData = new Object[(int)$numElemName];
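The message fix in patch 17 addresses a classic codegen pitfall: Scala string interpolation runs when the Java source is generated, so $numElements inside a Java string literal pastes in the generated variable's name, not its runtime value. A minimal illustration of the before/after shapes (the fresh-name value and the message text here are hypothetical):

```scala
val numElements = "numElements_0"  // stand-in for a ctx.freshName("numElements") result

// Before: the name is interpolated inside the Java literal, so the thrown
// message would literally read "concat failed with numElements_0 elements".
val before = s"""throw new RuntimeException("concat failed with $numElements elements");"""

// After: the literal is closed and the Java variable is concatenated in,
// so the runtime count appears in the message.
val after = s"""throw new RuntimeException("concat failed with " + $numElements + " elements");"""
```

Patch 18 is then purely cosmetic along the same lines: the cached javaType val replaces the repeated CodeGenerator.javaType(dataType) calls, and foldable is derived from the children, matching how nullable is already defined.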