From 134e9e467c0544f9f98e8ca864487ef57504b6af Mon Sep 17 00:00:00 2001 From: gschiavon Date: Mon, 9 Sep 2019 10:10:27 +0200 Subject: [PATCH 01/27] [SPARK-29020] Improving array_sort behaviour --- .../expressions/collectionOperations.scala | 38 ++++++++++++------- .../CollectionExpressionsSuite.scala | 21 +++++----- .../org/apache/spark/sql/functions.scala | 18 +++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++++++++- 4 files changed, 75 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5314821ea3a5..c976751acafc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -902,33 +902,45 @@ case class SortArray(base: Expression, ascendingOrder: Expression) /** - * Sorts the input array in ascending order according to the natural ordering of + * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. */ // scalastyle:off line.size.limit @ExpressionDescription( usage = """ _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must - be orderable. Null elements will be placed at the end of the returned array. + be orderable. Null elements will be placed at the beginning of the returned array in ascending + order or at the end of the returned array in descending order. """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - ["a","b","c","d",null] + [null,"a","b","c","d"] """, since = "2.4.0") // scalastyle:on line.size.limit -case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { +case class ArraySort(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + def this(e: Expression) = this(e, Literal(true)) - override def arrayExpression: Expression = child - override def nullOrder: NullOrder = NullOrder.Greatest + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - TypeCheckResult.TypeCheckSuccess + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } case ArrayType(dt, _) => val dtSimple = dt.catalogString TypeCheckResult.TypeCheckFailure( @@ -937,12 +949,12 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") } - override def nullSafeEval(array: Any): Any = { - sortEval(array, true) + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, 
ascending.asInstanceOf[Boolean]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + nullSafeCodeGen(ctx, ev, (c, order) => sortCodegen(ctx, ev, c, order)) } override def prettyName: String = "array_sort" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 603073b40d7a..400ad320b72d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -364,15 +364,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) - checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) - checkEvaluation(ArraySort(a1), Seq[Integer]()) - checkEvaluation(ArraySort(a2), Seq("a", "b")) - checkEvaluation(ArraySort(a3), Seq("a", "b", null)) - checkEvaluation(ArraySort(a4), Seq(d1, d2)) - checkEvaluation(ArraySort(a5), Seq(null, null)) - checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) - checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) - checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) + checkEvaluation(new ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a0, Literal(false)), Seq(3, 2, 1)) + checkEvaluation(new ArraySort(a1), Seq[Integer]()) + checkEvaluation(new ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a2, Literal(false)), Seq("b", "a")) + checkEvaluation(new ArraySort(a3), Seq(null, "a", "b")) + checkEvaluation(ArraySort(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(new ArraySort(a4), Seq(d1, d2)) + checkEvaluation(new ArraySort(a5), Seq(null, null)) + checkEvaluation(new ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(new ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(new ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 84e0eaff2d42..1a088e2d0136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3328,13 +3328,25 @@ object functions { } /** - * Sorts the input array in ascending order. The elements of the input array must be orderable. - * Null elements will be placed at the end of the returned array. + * Sorts the input array for the given column in ascending order, + * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 2.4.0 */ - def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + def array_sort(e: Column): Column = array_sort(e, asc = true) + + /** + * Sorts the input array for the given column in ascending or descending order, + * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. 
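+   * A minimal usage sketch (illustrative only; `df` is a hypothetical DataFrame with an
+   * array column named "a"):
+   * {{{
+   *   df.select(array_sort($"a", asc = false))
+   * }}}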
+ * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column, asc: Boolean): Column = withExpr { ArraySort(e.expr, lit(asc).expr) } /** * Remove all elements that equal to element from the given array. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7d044638db57..c2e7bcce9a9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -368,6 +368,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) + + checkAnswer( + df.select(array_sort($"a", false), array_sort($"b", false)), + Seq( + Row(Seq(3, 2, 1), Seq("c", "b", "a")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( df.selectExpr("array_sort(a)", "array_sort(b)"), Seq( @@ -376,9 +385,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null)) ) + checkAnswer( + df.selectExpr("array_sort(a, false)", "array_sort(b, false)"), + Seq( + Row(Seq(3, 2, 1), Seq("c", "b", "a")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( df2.selectExpr("array_sort(a)"), - Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + Seq(Row(Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)))) + ) + + checkAnswer( + df2.selectExpr("array_sort(a, false)"), + Seq(Row(Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) ) assert(intercept[AnalysisException] { From aeee71cb4c2bf5306815fb1919a47840e1f7ce92 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Tue, 10 Sep 2019 08:12:02 +0200 Subject: [PATCH 02/27] [SPARK-29020] [SQL] Keep array_sort original behaviour --- .../sql/catalyst/expressions/collectionOperations.scala | 6 ++++-- .../catalyst/expressions/CollectionExpressionsSuite.scala | 2 +- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 ++---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c976751acafc..6b74fdb19928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -915,7 +915,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - [null,"a","b","c","d"] + ["a","b","c","d",null] """, since = "2.4.0") // scalastyle:on line.size.limit @@ -930,7 +930,9 @@ case class ArraySort(base: Expression, ascendingOrder: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) override def arrayExpression: Expression = base - override def nullOrder: NullOrder = NullOrder.Least + override def nullOrder: NullOrder = { + if(ascendingOrder == Literal(true)) NullOrder.Greatest else NullOrder.Least + } override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 400ad320b72d..b8423bdcc745 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -369,7 +369,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new ArraySort(a1), Seq[Integer]()) checkEvaluation(new ArraySort(a2), Seq("a", "b")) checkEvaluation(ArraySort(a2, Literal(false)), Seq("b", "a")) - checkEvaluation(new ArraySort(a3), Seq(null, "a", "b")) + checkEvaluation(new ArraySort(a3), Seq("a", "b", null)) checkEvaluation(ArraySort(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(new ArraySort(a4), Seq(d1, d2)) checkEvaluation(new ArraySort(a5), Seq(null, null)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1a088e2d0136..b9593e26aa9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3330,7 +3330,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. - * Null elements will be placed at the beginning of the returned array. + * Null elements will be placed at the end of the returned array. * * @group collection_funcs * @since 2.4.0 @@ -3340,9 +3340,7 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. - * Null elements will be placed at the beginning of the returned array in ascending order or - * at the end of the returned array in descending order. 
- * + * Null elements will be placed at the end of the returned array in descending / ascending order * @group collection_funcs * @since 2.4.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c2e7bcce9a9f..f7384cb0af0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -395,7 +395,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( df2.selectExpr("array_sort(a)"), - Seq(Row(Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)))) + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) ) checkAnswer( From ebde5444f8b593326d783d10d66e0b932cbcc071 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Tue, 10 Sep 2019 08:24:50 +0200 Subject: [PATCH 03/27] [SPARK-29020] [SQL] ascending parameter is now a Column --- .../sql/catalyst/expressions/collectionOperations.scala | 9 +++++---- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 +++--- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6b74fdb19928..0e0738d0f6d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -908,9 +908,10 @@ case class SortArray(base: Expression, ascendingOrder: Expression) // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must - be orderable. Null elements will be placed at the beginning of the returned array in ascending - order or at the end of the returned array in descending order. + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array + in descending / ascending order """, examples = """ Examples: @@ -931,7 +932,7 @@ case class ArraySort(base: Expression, ascendingOrder: Expression) override def arrayExpression: Expression = base override def nullOrder: NullOrder = { - if(ascendingOrder == Literal(true)) NullOrder.Greatest else NullOrder.Least + if (ascendingOrder == Literal(true)) NullOrder.Greatest else NullOrder.Least } override def checkInputDataTypes(): TypeCheckResult = base.dataType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b9593e26aa9d..0ce323a69c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3335,16 +3335,16 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_sort(e: Column): Column = array_sort(e, asc = true) + def array_sort(e: Column): Column = array_sort(e, asc = lit(true)) /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. 
* Null elements will be placed at the end of the returned array in descending / ascending order * @group collection_funcs - * @since 2.4.0 + * @since 3.0.0 */ - def array_sort(e: Column, asc: Boolean): Column = withExpr { ArraySort(e.expr, lit(asc).expr) } + def array_sort(e: Column, asc: Column): Column = withExpr { ArraySort(e.expr, asc.expr) } /** * Remove all elements that equal to element from the given array. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f7384cb0af0f..882ea0183065 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -370,7 +370,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) checkAnswer( - df.select(array_sort($"a", false), array_sort($"b", false)), + df.select(array_sort($"a", lit(false)), array_sort($"b", lit(false))), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), Row(Seq.empty[Int], Seq.empty[String]), From 3f4c32837852175f05c42ae0c58559bdc564d220 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Sun, 15 Sep 2019 15:59:01 +0200 Subject: [PATCH 04/27] Array sorting as HOF with Integers Asc and Desc --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 36 +++++- .../expressions/higherOrderFunctions.scala | 110 +++++++++++++++++- .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 52 ++++++++- 5 files changed, 204 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d5728b902757..24000210c2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -454,6 +454,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), + expression[ArraySorting]("array_new_sort"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0e0738d0f6d0..795ff67ce124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -712,7 +712,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient private lazy val lt: Comparator[Any] = { + @transient lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -963,6 +963,40 @@ case class ArraySort(base: Expression, ascendingOrder: Expression) override def prettyName: String = "array_sort" } +//scalastyle:off + +case class ArraySortF(child: Expression, function: Expression) extends UnaryExpression with ArraySortLike { + + override 
def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + function.asInstanceOf[LambdaFunction].arguments + function.asInstanceOf[LambdaFunction].function + function.asInstanceOf[LambdaFunction].dataType + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.catalogString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + /** * Returns a random permutation of the given array. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ed26bb375de2..622c0521fff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Comparator import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods +import scala.util.Sorting + /** * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. 
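  * For intuition: in a query such as `transform(xs, x -> x + 1)`, the name `x` is first
  * parsed as one of these placeholders and is only bound to the element type of `xs`
  * during analysis (a simplified description of the flow, not the exact algorithm).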
*/ @@ -269,7 +272,7 @@ case class ArrayTransform( override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval - val result = new GenericArrayData(new Array[Any](arr.numElements)) + var result = new GenericArrayData(new Array[Any](arr.numElements)) var i = 0 while (i < arr.numElements) { elementVar.value.set(arr.get(i, elementVar.dataType)) @@ -285,6 +288,109 @@ case class ArrayTransform( override def prettyName: String = "transform" } +//scalastyle:off +/***************************************/ +/***************************************/ +/***************************************/ +/***************************************/ +/***************************************/ +/***************************************/ + +case class ArraySorting( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySorting = { + val ArrayType(elementType, containsNull) = argument.dataType + function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + copy(function = f(function, (elementType, containsNull) :: (elementType, false) :: Nil)) + case _ => + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + } + + @transient lazy val (firstParam, secondParam) = { + val LambdaFunction(_, (firstParam: NamedLambdaVariable) +: tail, _) = function + val secondParam = tail.head.asInstanceOf[NamedLambdaVariable] + (firstParam, secondParam) + } + + @transient lazy val elementType: DataType = + function.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + val f = functionForEval + val result = new GenericArrayData(new Array[Any](arr.numElements)) + + @transient lazy val customComparator: Comparator[Any] = { + (o1: Any, o2: Any) => { + firstParam.value.set(o2) + secondParam.value.set(o1) + f.eval(inputRow).asInstanceOf[Int] + } + } + + + + for(i <- 0 until arr.numElements()){ + result.update(i, arr.get(i, firstParam.dataType)) + } + val p = sortEval(result, customComparator) + val t = 2 + p + + } + + def sortEval(array: Any, comparator: Comparator[Any]): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](firstParam.dataType) + if (firstParam.dataType!= NullType) { + java.util.Arrays.sort(data, comparator) + } + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + + override def prettyName: String = "array_new_sort" + + // var i = 0 + // while (i < arr.numElements - 1) { + // firstParam.value.set(arr.get(i, firstParam.dataType)) + // secondParam.value.set(i) + // result.update(i, f.eval(inputRow)) + // i += 1 + // } + + // for(i <- 0 until arr.numElements() -1){ + // result.update(i, arr.get(i, firstParam.dataType)) + // } + // for(i <- 0 until arr.numElements()-1) { + // for(j<-0 until arr.numElements - i-1) { + // firstParam.value.set(arr.get(j, firstParam.dataType)) + // secondParam.value.set(arr.get(j + 1, secondParam.dataType)) + // if(f.eval(inputRow) == -1) { + // var temp = arr.get(j, firstParam.dataType) + // arr.update(j, arr.get(j+1, firstParam.dataType)) + // arr.update(j+1 ,temp) + // } + // } + // } + + // result + +} + + +/***************************************/ +/***************************************/ 
+/***************************************/ +/***************************************/ +/***************************************/ +/***************************************/ + /** * Filters entries in a map using the provided function. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0ce323a69c68..fdcc84f0878a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3346,6 +3346,16 @@ object functions { */ def array_sort(e: Column, asc: Column): Column = withExpr { ArraySort(e.expr, asc.expr) } + private def createLambda(f: (Column, Column) => Column) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val y = UnresolvedNamedLambdaVariable(Seq("y")) + val function = f(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } +//scalastyle:off + def array_sort_f(e: Column, f: (Column, Column) => Column): Column = withExpr { ArraySortF(e.expr, createLambda(f)) } + + /** * Remove all elements that equal to element from the given array. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 882ea0183065..d380a1cc7986 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -22,7 +22,6 @@ import java.sql.{Date, Timestamp} import java.util.TimeZone import scala.util.Random - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -33,6 +32,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import scala.collection.mutable.ArrayBuffer + /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
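 * Patch 04 also adds an experimental `array_sort_f` Column API in functions.scala; a
 * hypothetical call shape, built only from standard Column operators, would be:
 * {{{
 *   df.select(array_sort_f($"a", (x, y) => when(x < y, -1).when(x === y, 0).otherwise(1)))
 * }}}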
*/ @@ -311,8 +312,55 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(2)) ) } +//scalastyle:off + test("array_sort with lambda functions") { + + spark.udf.register("fAsc", (x: Int, y: Int) => { + if(x < y) -1 + else if(x == y) 0 + else 1 + }) + + spark.udf.register("fDesc", (x: Int, y: Int) => { + if(x < y) 1 + else if(x == y) 0 + else -1 + }) + + spark.udf.register("fStringAsc", (x: String, y: String) => { + if(x < y) -1 + else if(x == y) 0 + else 1 + }) + + val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") + val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") + + checkAnswer( + df1.selectExpr("array_new_sort(a, (b, i) -> fAsc(b,i))"), + Seq( + Row(Seq(5, 3, 2, 2, 1)))) + + checkAnswer( + df1.selectExpr("array_new_sort(a, (b, i) -> fDesc(b,i))"), + Seq( + Row(Seq(1, 2, 2, 3, 5)))) + +// checkAnswer( +// df2.selectExpr("array_new_sort(a, (b, c) -> fStringAsc(b,c))"), +// Seq( +// Row(Seq("dc", "bc", "ab")))) + } + + + + + + + + - test("sort_array/array_sort functions") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), From 7264fc530fc0d85964b9b5428dbe99e9fa77aab8 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Sun, 15 Sep 2019 18:10:18 +0200 Subject: [PATCH 05/27] 4/6 cases working --- .../expressions/higherOrderFunctions.scala | 53 ++++++------------- .../spark/sql/DataFrameFunctionsSuite.scala | 41 ++++++++++---- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 622c0521fff4..536a5a2c668d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -301,13 +301,16 @@ case class ArraySorting( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + @transient lazy val elementType: DataType = + argument.dataType.asInstanceOf[ArrayType].elementType + override def dataType: ArrayType = ArrayType(elementType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySorting = { val ArrayType(elementType, containsNull) = argument.dataType function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => - copy(function = f(function, (elementType, containsNull) :: (elementType, false) :: Nil)) + copy(function = f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) case _ => copy(function = f(function, (elementType, containsNull) :: Nil)) } @@ -319,36 +322,36 @@ case class ArraySorting( (firstParam, secondParam) } - @transient lazy val elementType: DataType = - function.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval - val result = new GenericArrayData(new Array[Any](arr.numElements)) @transient lazy val customComparator: Comparator[Any] = { (o1: Any, o2: Any) => { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + 1 + } else if (o2 == null) { + -1 + } else { firstParam.value.set(o2) secondParam.value.set(o1) 
f.eval(inputRow).asInstanceOf[Int] + } } } - - - for(i <- 0 until arr.numElements()){ - result.update(i, arr.get(i, firstParam.dataType)) - } - val p = sortEval(result, customComparator) + val p = sortEval(arr, customComparator) val t = 2 p } def sortEval(array: Any, comparator: Comparator[Any]): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](firstParam.dataType) - if (firstParam.dataType!= NullType) { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementType!= NullType) { java.util.Arrays.sort(data, comparator) } new GenericArrayData(data.asInstanceOf[Array[Any]]) @@ -356,30 +359,6 @@ case class ArraySorting( override def prettyName: String = "array_new_sort" - // var i = 0 - // while (i < arr.numElements - 1) { - // firstParam.value.set(arr.get(i, firstParam.dataType)) - // secondParam.value.set(i) - // result.update(i, f.eval(inputRow)) - // i += 1 - // } - - // for(i <- 0 until arr.numElements() -1){ - // result.update(i, arr.get(i, firstParam.dataType)) - // } - // for(i <- 0 until arr.numElements()-1) { - // for(j<-0 until arr.numElements - i-1) { - // firstParam.value.set(arr.get(j, firstParam.dataType)) - // secondParam.value.set(arr.get(j + 1, secondParam.dataType)) - // if(f.eval(inputRow) == -1) { - // var temp = arr.get(j, firstParam.dataType) - // arr.update(j, arr.get(j+1, firstParam.dataType)) - // arr.update(j+1 ,temp) - // } - // } - // } - - // result } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d380a1cc7986..fd4bdc67518b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -327,14 +327,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { else -1 }) - spark.udf.register("fStringAsc", (x: String, y: String) => { + spark.udf.register("fString", (x: String, y: String) => { if(x < y) -1 else if(x == y) 0 else 1 }) + spark.udf.register("fStringLength", (x: String, y: String) => { + if(x.length < y.length) 1 + else if(x.length == y.length) 0 + else -1 + }) + + spark.udf.register("fArraylength", (x: Int, y: Int) => { + if(x < y) 1 + else if(x == y) 0 + else -1 + }) + + val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") - val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") checkAnswer( df1.selectExpr("array_new_sort(a, (b, i) -> fAsc(b,i))"), @@ -346,19 +358,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq( Row(Seq(1, 2, 2, 3, 5)))) -// checkAnswer( -// df2.selectExpr("array_new_sort(a, (b, c) -> fStringAsc(b,c))"), -// Seq( -// Row(Seq("dc", "bc", "ab")))) - } - - - - + val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") + checkAnswer( + df2.selectExpr("array_new_sort(a, (b, i) -> fString(b,i))"), + Seq( + Row(Seq("dc", "bc", "ab")))) + val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") + checkAnswer( + df3.selectExpr("array_new_sort(a, (b, i) -> fStringLength(b,i))"), + Seq( + Row(Seq("a", "abc", "abcd")))) + val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")).toDF("a", "b") + checkAnswer( + df4.selectExpr("array_new_sort(a, (b, i) -> fArraylength(cardinality(b),cardinality(i)))"), + Seq( + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2,3,1), Seq(4, 2, 1, 4))))) + } test("sort_array/array_sort functions") { val df = Seq( From 210baf76163444f54b6cc5e1ae7dbb48fd219776 Mon Sep 17 
00:00:00 2001 From: gschiavon Date: Sun, 15 Sep 2019 21:29:07 +0200 Subject: [PATCH 06/27] undo array_sort asc flag --- .../expressions/collectionOperations.scala | 61 ++----------------- .../expressions/higherOrderFunctions.scala | 31 ++-------- .../CollectionExpressionsSuite.scala | 21 +++---- .../org/apache/spark/sql/functions.scala | 24 +------- .../spark/sql/DataFrameFunctionsSuite.scala | 26 +------- 5 files changed, 24 insertions(+), 139 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 795ff67ce124..5314821ea3a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -712,7 +712,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -902,16 +902,14 @@ case class SortArray(base: Expression, ascendingOrder: Expression) /** - * Sorts the input array in ascending / descending order according to the natural ordering of + * Sorts the input array in ascending order according to the natural ordering of * the array elements and returns it. */ // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order - according to the natural ordering of the array elements. The elements of the input array must - be orderable. Null elements will be placed at the end of the returned array - in descending / ascending order + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. 
""", examples = """ Examples: @@ -920,61 +918,14 @@ case class SortArray(base: Expression, ascendingOrder: Expression) """, since = "2.4.0") // scalastyle:on line.size.limit -case class ArraySort(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ArraySortLike { - - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def arrayExpression: Expression = base - override def nullOrder: NullOrder = { - if (ascendingOrder == Literal(true)) NullOrder.Greatest else NullOrder.Least - } - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - val dtSimple = dt.catalogString - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type $dtSimple which is not orderable") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } - - override def nullSafeEval(array: Any, ascending: Any): Any = { - sortEval(array, ascending.asInstanceOf[Boolean]) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (c, order) => sortCodegen(ctx, ev, c, order)) - } - - override def prettyName: String = "array_sort" -} - -//scalastyle:off - -case class ArraySortF(child: Expression, function: Expression) extends UnaryExpression with ArraySortLike { +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { override def dataType: DataType = child.dataType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def arrayExpression: Expression = child override def nullOrder: NullOrder = NullOrder.Greatest - function.asInstanceOf[LambdaFunction].arguments - function.asInstanceOf[LambdaFunction].function - function.asInstanceOf[LambdaFunction].dataType + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 536a5a2c668d..f21534392aab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -23,15 +23,12 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods -import scala.util.Sorting - /** * A placeholder of lambda 
variables to prevent unexpected resolution of [[LambdaFunction]]. */ @@ -272,7 +269,7 @@ case class ArrayTransform( override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval - var result = new GenericArrayData(new Array[Any](arr.numElements)) + val result = new GenericArrayData(new Array[Any](arr.numElements)) var i = 0 while (i < arr.numElements) { elementVar.value.set(arr.get(i, elementVar.dataType)) @@ -288,17 +285,10 @@ case class ArrayTransform( override def prettyName: String = "transform" } -//scalastyle:off -/***************************************/ -/***************************************/ -/***************************************/ -/***************************************/ -/***************************************/ -/***************************************/ case class ArraySorting( - argument: Expression, - function: Expression) + argument: Expression, + function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { @@ -310,7 +300,8 @@ case class ArraySorting( val ArrayType(elementType, containsNull) = argument.dataType function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => - copy(function = f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) + copy(function = + f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) case _ => copy(function = f(function, (elementType, containsNull) :: Nil)) } @@ -343,10 +334,7 @@ case class ArraySorting( } } - val p = sortEval(arr, customComparator) - val t = 2 - p - + sortEval(arr, customComparator) } def sortEval(array: Any, comparator: Comparator[Any]): Any = { @@ -363,13 +351,6 @@ case class ArraySorting( } -/***************************************/ -/***************************************/ -/***************************************/ -/***************************************/ -/***************************************/ -/***************************************/ - /** * Filters entries in a map using the provided function. 
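 * For example (illustrative only):
 * {{{
 *   SELECT map_filter(map(1, 0, 2, 3), (k, v) -> k > v);   -- keeps only 1 -> 0
 * }}}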
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b8423bdcc745..603073b40d7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -364,18 +364,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) - checkEvaluation(new ArraySort(a0), Seq(1, 2, 3)) - checkEvaluation(ArraySort(a0, Literal(false)), Seq(3, 2, 1)) - checkEvaluation(new ArraySort(a1), Seq[Integer]()) - checkEvaluation(new ArraySort(a2), Seq("a", "b")) - checkEvaluation(ArraySort(a2, Literal(false)), Seq("b", "a")) - checkEvaluation(new ArraySort(a3), Seq("a", "b", null)) - checkEvaluation(ArraySort(a3, Literal(false)), Seq("b", "a", null)) - checkEvaluation(new ArraySort(a4), Seq(d1, d2)) - checkEvaluation(new ArraySort(a5), Seq(null, null)) - checkEvaluation(new ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) - checkEvaluation(new ArraySort(arrayArray), Seq(aa1, aa2)) - checkEvaluation(new ArraySort(arrayArrayStruct), Seq(aas1, aas2)) + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fdcc84f0878a..84e0eaff2d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3328,33 +3328,13 @@ object functions { } /** - * Sorts the input array for the given column in ascending order, - * according to the natural ordering of the array elements. + * Sorts the input array in ascending order. The elements of the input array must be orderable. * Null elements will be placed at the end of the returned array. * * @group collection_funcs * @since 2.4.0 */ - def array_sort(e: Column): Column = array_sort(e, asc = lit(true)) - - /** - * Sorts the input array for the given column in ascending or descending order, - * according to the natural ordering of the array elements. 
- * Null elements will be placed at the end of the returned array in descending / ascending order - * @group collection_funcs - * @since 3.0.0 - */ - def array_sort(e: Column, asc: Column): Column = withExpr { ArraySort(e.expr, asc.expr) } - - private def createLambda(f: (Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) - } -//scalastyle:off - def array_sort_f(e: Column, f: (Column, Column) => Column): Column = withExpr { ArraySortF(e.expr, createLambda(f)) } - + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } /** * Remove all elements that equal to element from the given array. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fd4bdc67518b..f5fcaaf0a268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import java.util.TimeZone import scala.util.Random + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -32,8 +33,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ -import scala.collection.mutable.ArrayBuffer - /** * Test suite for functions in [[org.apache.spark.sql.functions]]. */ @@ -312,7 +311,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(2)) ) } -//scalastyle:off test("array_sort with lambda functions") { spark.udf.register("fAsc", (x: Int, y: Int) => { @@ -435,15 +433,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) - - checkAnswer( - df.select(array_sort($"a", lit(false)), array_sort($"b", lit(false))), - Seq( - Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq.empty[Int], Seq.empty[String]), - Row(null, null)) - ) - checkAnswer( df.selectExpr("array_sort(a)", "array_sort(b)"), Seq( @@ -452,24 +441,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null)) ) - checkAnswer( - df.selectExpr("array_sort(a, false)", "array_sort(b, false)"), - Seq( - Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq.empty[Int], Seq.empty[String]), - Row(null, null)) - ) - checkAnswer( df2.selectExpr("array_sort(a)"), Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) ) - checkAnswer( - df2.selectExpr("array_sort(a, false)"), - Seq(Row(Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) - ) - assert(intercept[AnalysisException] { df3.selectExpr("array_sort(a)").collect() }.getMessage().contains("only supports array input")) From f7e7d394f5cdc53ee683414462a5614f79027b8d Mon Sep 17 00:00:00 2001 From: gschiavon Date: Mon, 16 Sep 2019 08:21:49 +0200 Subject: [PATCH 07/27] Name and indent refactor --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/higherOrderFunctions.scala | 22 +++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 41 ++++++++----------- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 24000210c2db..351a4b9bf241 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -454,7 +454,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), - expression[ArraySorting]("array_new_sort"), + expression[ArraySorting]("array_sort"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index f21534392aab..2479288934f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -285,7 +285,25 @@ case class ArrayTransform( override def prettyName: String = "transform" } - +/** + * Sorts elements in an array using a comparator function. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Sorts and returns the array based on the given " + + "comparator function. The comparator will take two nullable arguments " + + "representing two nullable elements of the array." + + "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater" + + "than the second nullable element." + + "If the comparator function returns other values (including NULL)," + + "the query will fail and raise an error", + examples = """ + Examples: + > SELECT _FUNC_(array(5, 6, 1), (x,y) -> f(x,y)); + [1,5,6] + > SELECT _FUNC_(array('bc', 'ab', 'dc'), (x, y) -> f(x,y)); + ['dc', 'bc', 'ab'] + """, + since = "3.0.0") case class ArraySorting( argument: Expression, function: Expression) @@ -345,7 +363,7 @@ case class ArraySorting( new GenericArrayData(data.asInstanceOf[Array[Any]]) } - override def prettyName: String = "array_new_sort" + override def prettyName: String = "array_sort" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f5fcaaf0a268..88c338f51578 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -311,73 +311,68 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(2)) ) } + test("array_sort with lambda functions") { spark.udf.register("fAsc", (x: Int, y: Int) => { - if(x < y) -1 - else if(x == y) 0 + if (x < y) -1 + else if (x == y) 0 else 1 }) spark.udf.register("fDesc", (x: Int, y: Int) => { - if(x < y) 1 - else if(x == y) 0 + if (x < y) 1 + else if (x == y) 0 else -1 }) spark.udf.register("fString", (x: String, y: String) => { - if(x < y) -1 - else if(x == y) 0 + if (x < y) -1 + else if (x == y) 0 else 1 }) spark.udf.register("fStringLength", (x: String, y: String) => { - if(x.length < y.length) 1 - else if(x.length == y.length) 0 - else -1 - }) - - spark.udf.register("fArraylength", (x: Int, y: Int) => { - if(x < y) 1 - else if(x == y) 0 + if (x.length < y.length) 1 + else if (x.length == y.length) 0 else -1 }) - val df1 = Seq(Array[Int](3, 2, 5, 1, 
2)).toDF("a") checkAnswer( - df1.selectExpr("array_new_sort(a, (b, i) -> fAsc(b,i))"), + df1.selectExpr("array_sort(a, (b, i) -> fAsc(b,i))"), Seq( Row(Seq(5, 3, 2, 2, 1)))) checkAnswer( - df1.selectExpr("array_new_sort(a, (b, i) -> fDesc(b,i))"), + df1.selectExpr("array_sort(a, (b, i) -> fDesc(b,i))"), Seq( Row(Seq(1, 2, 2, 3, 5)))) val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") checkAnswer( - df2.selectExpr("array_new_sort(a, (b, i) -> fString(b,i))"), + df2.selectExpr("array_sort(a, (b, i) -> fString(b,i))"), Seq( Row(Seq("dc", "bc", "ab")))) val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") checkAnswer( - df3.selectExpr("array_new_sort(a, (b, i) -> fStringLength(b,i))"), + df3.selectExpr("array_sort(a, (b, i) -> fStringLength(b,i))"), Seq( Row(Seq("a", "abc", "abcd")))) - val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")).toDF("a", "b") + val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")) + .toDF("a", "b") checkAnswer( - df4.selectExpr("array_new_sort(a, (b, i) -> fArraylength(cardinality(b),cardinality(i)))"), + df4.selectExpr("array_sort(a, (b, i) -> fDesc(cardinality(b),cardinality(i)))"), Seq( - Row(Seq[Seq[Int]](Seq(1, 2), Seq(2,3,1), Seq(4, 2, 1, 4))))) + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4))))) } - test("sort_array/array_sort functions") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), From e651094d752049c1b46e2ea4179505dc6b3e1fb7 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Mon, 16 Sep 2019 21:03:48 +0200 Subject: [PATCH 08/27] Added null handling --- .../expressions/higherOrderFunctions.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2479288934f9..771128c30792 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -312,7 +312,7 @@ case class ArraySorting( @transient lazy val elementType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType - override def dataType: ArrayType = ArrayType(elementType, function.nullable) + override def dataType: ArrayType = ArrayType(elementType, argument.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySorting = { val ArrayType(elementType, containsNull) = argument.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 88c338f51578..2bb740a20141 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -341,24 +341,24 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") checkAnswer( - df1.selectExpr("array_sort(a, (b, i) -> fAsc(b,i))"), + df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"), Seq( Row(Seq(5, 3, 2, 2, 1)))) checkAnswer( - df1.selectExpr("array_sort(a, (b, i) -> fDesc(b,i))"), + df1.selectExpr("array_sort(a, 
(x, y) -> fDesc(x, y))"), Seq( Row(Seq(1, 2, 2, 3, 5)))) val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") checkAnswer( - df2.selectExpr("array_sort(a, (b, i) -> fString(b,i))"), + df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( Row(Seq("dc", "bc", "ab")))) val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") checkAnswer( - df3.selectExpr("array_sort(a, (b, i) -> fStringLength(b,i))"), + df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"), Seq( Row(Seq("a", "abc", "abcd")))) @@ -366,10 +366,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")) .toDF("a", "b") checkAnswer( - df4.selectExpr("array_sort(a, (b, i) -> fDesc(cardinality(b),cardinality(i)))"), + df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), Seq( Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4))))) + val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") + checkAnswer( + df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), + Seq( + Row(Seq("dc", "bc", "ab", null)))) + } test("sort_array/array_sort functions") { From b32a9fa415a8c9943a373a12721e88cf402a41d0 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Mon, 16 Sep 2019 21:07:40 +0200 Subject: [PATCH 09/27] fix import for scalastyle --- .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 771128c30792..12acf782edbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -21,6 +21,7 @@ import java.util.Comparator import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.codegen._ From cc44b909a7df769a7a518a752dc35263a8fe3842 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Mon, 16 Sep 2019 22:54:41 +0200 Subject: [PATCH 10/27] added checkInputDataTypes function --- .../expressions/higherOrderFunctions.scala | 20 ++++++++++++++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 5 ++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 12acf782edbd..2469416482a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -294,8 +294,8 @@ case class ArrayTransform( "comparator function. The comparator will take two nullable arguments " + "representing two nullable elements of the array." + "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater" + - "than the second nullable element." + - "If the comparator function returns other values (including NULL)," + + "than the second nullable element. 
Null elements will be placed at the end of the returned" + + "array. If the comparator function returns other values (including NULL)," + "the query will fail and raise an error", examples = """ Examples: @@ -310,11 +310,25 @@ case class ArraySorting( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - @transient lazy val elementType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType + override def dataType: ArrayType = ArrayType(elementType, argument.nullable) + override def checkInputDataTypes(): TypeCheckResult = { + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + val LambdaFunction(_, arguments, _) = function + if (arguments.size == 2 && function.dataType == IntegerType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Return type of the given function has to be" + + "IntegerType") + } + case failure => failure + } + } + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySorting = { val ArrayType(elementType, containsNull) = argument.dataType function match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2bb740a20141..9f8590a13843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -362,13 +362,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq( Row(Seq("a", "abc", "abcd")))) - - val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")) + val df4 = Seq((Array[Array[Int]](null, Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")) .toDF("a", "b") checkAnswer( df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), Seq( - Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4))))) + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null)))) val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") checkAnswer( From 2285678593b6e35d72408819de9bb805dfa6f397 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Tue, 17 Sep 2019 16:26:57 +0200 Subject: [PATCH 11/27] rename HOF array_sort to sort --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/higherOrderFunctions.scala | 13 ++++-------- .../spark/sql/DataFrameFunctionsSuite.scala | 21 ++++++++++--------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 351a4b9bf241..5afd8a415552 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -454,7 +454,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), - expression[ArraySorting]("array_sort"), + expression[ArraySorting]("sort"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala 
index 2469416482a1..6385cdd0213f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -310,10 +310,10 @@ case class ArraySorting( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - @transient lazy val elementType: DataType = + @transient lazy val argumenType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType - override def dataType: ArrayType = ArrayType(elementType, argument.nullable) + override def dataType: ArrayType = ArrayType(argumenType, argument.nullable) override def checkInputDataTypes(): TypeCheckResult = { checkArgumentDataTypes() match { @@ -331,13 +331,8 @@ case class ArraySorting( override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySorting = { val ArrayType(elementType, containsNull) = argument.dataType - function match { - case LambdaFunction(_, arguments, _) if arguments.size == 2 => copy(function = f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) - case _ => - copy(function = f(function, (elementType, containsNull) :: Nil)) - } } @transient lazy val (firstParam, secondParam) = { @@ -371,8 +366,8 @@ case class ArraySorting( } def sortEval(array: Any, comparator: Comparator[Any]): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementType!= NullType) { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](argumenType) + if (argumenType!= NullType) { java.util.Arrays.sort(data, comparator) } new GenericArrayData(data.asInstanceOf[Array[Any]]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9f8590a13843..1ed1414630ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -341,37 +341,37 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") checkAnswer( - df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"), + df1.selectExpr("sort(a, (x, y) -> fAsc(x, y))"), Seq( Row(Seq(5, 3, 2, 2, 1)))) checkAnswer( - df1.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))"), + df1.selectExpr("sort(a, (x, y) -> fDesc(x, y))"), Seq( Row(Seq(1, 2, 2, 3, 5)))) val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") checkAnswer( - df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), + df2.selectExpr("sort(a, (x, y) -> fString(x, y))"), Seq( Row(Seq("dc", "bc", "ab")))) val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") checkAnswer( - df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"), + df3.selectExpr("sort(a, (x, y) -> fStringLength(x, y))"), Seq( Row(Seq("a", "abc", "abcd")))) - val df4 = Seq((Array[Array[Int]](null, Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")) - .toDF("a", "b") + val df4 = Seq((Array[Array[Int]](null, Array(2, 3, 1), null, Array(4, 2, 1, 4), + Array(1, 2)), "x")).toDF("a", "b") checkAnswer( - df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), + df4.selectExpr("sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), Seq( - Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null)))) + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null, null)))) val df5 = 
Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") checkAnswer( - df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), + df5.selectExpr("sort(a, (x, y) -> fString(x, y))"), Seq( Row(Seq("dc", "bc", "ab", null)))) @@ -383,7 +383,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (Array.empty[Int], Array.empty[String]), (null, null) ).toDF("a", "b") - checkAnswer( + checkAnswer( df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), @@ -412,6 +412,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null)) ) + val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b") checkAnswer( df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"), From d35d6cb7dd9d0ae80c006b58b5c1fb3b103a1ba6 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Wed, 18 Sep 2019 09:37:49 +0200 Subject: [PATCH 12/27] Remove ArraySort and unifiyng it in new ArraySort HOF --- .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../expressions/collectionOperations.scala | 48 ------------ .../expressions/higherOrderFunctions.scala | 78 ++++++++++++++----- .../CollectionExpressionsSuite.scala | 10 --- .../org/apache/spark/sql/functions.scala | 9 --- .../spark/sql/DataFrameFunctionsSuite.scala | 25 ++---- 6 files changed, 68 insertions(+), 105 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5afd8a415552..15372ae207f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -426,7 +426,6 @@ object FunctionRegistry { expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), - expression[ArraySort]("array_sort"), expression[ArrayExcept]("array_except"), expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), @@ -454,7 +453,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), - expression[ArraySorting]("sort"), + expression[ArraySort]("array_sort"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5314821ea3a5..d2c0cf5b264f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -900,54 +900,6 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } - -/** - * Sorts the input array in ascending order according to the natural ordering of - * the array elements and returns it. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must - be orderable. Null elements will be placed at the end of the returned array. 
- """, - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - ["a","b","c","d",null] - """, - since = "2.4.0") -// scalastyle:on line.size.limit -case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { - - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - - override def arrayExpression: Expression = child - override def nullOrder: NullOrder = NullOrder.Greatest - - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - TypeCheckResult.TypeCheckSuccess - case ArrayType(dt, _) => - val dtSimple = dt.catalogString - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type $dtSimple which is not orderable") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } - - override def nullSafeEval(array: Any): Any = { - sortEval(array, true) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) - } - - override def prettyName: String = "array_sort" -} - /** * Returns a random permutation of the given array. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 6385cdd0213f..ca1ae535c262 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -296,7 +296,8 @@ case class ArrayTransform( "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater" + "than the second nullable element. Null elements will be placed at the end of the returned" + "array. If the comparator function returns other values (including NULL)," + - "the query will fail and raise an error", + "the query will fail and raise an error. 
By the default it will sort the array in " + + "ascending mode", examples = """ Examples: > SELECT _FUNC_(array(5, 6, 1), (x,y) -> f(x,y)); @@ -305,31 +306,68 @@ case class ArrayTransform( ['dc', 'bc', 'ab'] """, since = "3.0.0") -case class ArraySorting( +case class ArraySort( argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - @transient lazy val argumenType: DataType = + def this(argument: Expression) = this(argument, Literal(true)) + + @transient lazy val argumentsType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType - override def dataType: ArrayType = ArrayType(argumenType, argument.nullable) + @transient lazy val lt: Comparator[Any] = { + val ordering = argument.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + } + + (o1: Any, o2: Any) => { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + 1 + } else if (o2 == null) { + -1 + } else { + ordering.compare(o1, o2) + } + } + } + + override def dataType: ArrayType = ArrayType(argumentsType, argument.nullable) override def checkInputDataTypes(): TypeCheckResult = { - checkArgumentDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => - val LambdaFunction(_, arguments, _) = function - if (arguments.size == 2 && function.dataType == IntegerType) { + + if (function.dataType == BooleanType) { + argument.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure("Return type of the given function has to be" + - "IntegerType") - } - case failure => failure + case ArrayType(dt, _) => + val dtSimple = dt.catalogString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + } else { + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + val LambdaFunction(_, arguments, _) = function + if (arguments.size == 2 && function.dataType == IntegerType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Return type of the given function has to be" + + " IntegerType") + } + + case failure => failure + } } } - override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySorting = { + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySort = { val ArrayType(elementType, containsNull) = argument.dataType copy(function = f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) @@ -341,7 +379,6 @@ case class ArraySorting( (firstParam, secondParam) } - override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval @@ -362,12 +399,16 @@ case class ArraySorting( } } - sortEval(arr, customComparator) + if (function.dataType == BooleanType) { + sortEval(arr, lt) + } else { + sortEval(arr, customComparator) + } } def sortEval(array: Any, comparator: Comparator[Any]): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](argumenType) - if (argumenType!= NullType) { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](argumentsType) + if 
(argumentType!= NullType) { java.util.Arrays.sort(data, comparator) } new GenericArrayData(data.asInstanceOf[Array[Any]]) @@ -375,7 +416,6 @@ case class ArraySorting( override def prettyName: String = "array_sort" - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 603073b40d7a..79e7904ec588 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -363,16 +363,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) - - checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) - checkEvaluation(ArraySort(a1), Seq[Integer]()) - checkEvaluation(ArraySort(a2), Seq("a", "b")) - checkEvaluation(ArraySort(a3), Seq("a", "b", null)) - checkEvaluation(ArraySort(a4), Seq(d1, d2)) - checkEvaluation(ArraySort(a5), Seq(null, null)) - checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) - checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) - checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 84e0eaff2d42..352e18808ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3327,15 +3327,6 @@ object functions { ElementAt(column.expr, lit(value).expr) } - /** - * Sorts the input array in ascending order. The elements of the input array must be orderable. - * Null elements will be placed at the end of the returned array. - * - * @group collection_funcs - * @since 2.4.0 - */ - def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } - /** * Remove all elements that equal to element from the given array. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1ed1414630ad..5d5a015ec561 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -339,39 +339,38 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }) val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") - checkAnswer( - df1.selectExpr("sort(a, (x, y) -> fAsc(x, y))"), + df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"), Seq( Row(Seq(5, 3, 2, 2, 1)))) checkAnswer( - df1.selectExpr("sort(a, (x, y) -> fDesc(x, y))"), + df1.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))"), Seq( Row(Seq(1, 2, 2, 3, 5)))) val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") checkAnswer( - df2.selectExpr("sort(a, (x, y) -> fString(x, y))"), + df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( Row(Seq("dc", "bc", "ab")))) val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") checkAnswer( - df3.selectExpr("sort(a, (x, y) -> fStringLength(x, y))"), + df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"), Seq( Row(Seq("a", "abc", "abcd")))) val df4 = Seq((Array[Array[Int]](null, Array(2, 3, 1), null, Array(4, 2, 1, 4), Array(1, 2)), "x")).toDF("a", "b") checkAnswer( - df4.selectExpr("sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), + df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), Seq( Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null, null)))) val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") checkAnswer( - df5.selectExpr("sort(a, (x, y) -> fString(x, y))"), + df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( Row(Seq("dc", "bc", "ab", null)))) @@ -412,7 +411,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null)) ) - val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b") checkAnswer( df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"), @@ -427,13 +425,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) - checkAnswer( - df.select(array_sort($"a"), array_sort($"b")), - Seq( - Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq.empty[Int], Seq.empty[String]), - Row(null, null)) - ) checkAnswer( df.selectExpr("array_sort(a)", "array_sort(b)"), Seq( @@ -447,9 +438,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) ) - assert(intercept[AnalysisException] { + assert(intercept[AnalysisException] { df3.selectExpr("array_sort(a)").collect() - }.getMessage().contains("only supports array input")) + }.getMessage().contains("argument 1 requires array type, however, '`a`' is of string type")) } def testSizeOfArray(sizeOfNull: Any): Unit = { From ad045997aa573fb5fce321953e31cb92e56632ee Mon Sep 17 00:00:00 2001 From: gschiavon Date: Wed, 18 Sep 2019 15:09:29 +0200 Subject: [PATCH 13/27] added sortEval in ArraySortLike --- .../expressions/collectionOperations.scala | 12 +++++- .../expressions/higherOrderFunctions.scala | 38 +++---------------- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d2c0cf5b264f..8d9c4a520b46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -712,7 +712,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient private lazy val lt: Comparator[Any] = { + @transient lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -732,7 +732,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient private lazy val gt: Comparator[Any] = { + @transient lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -765,6 +765,14 @@ trait ArraySortLike extends ExpectsInputTypes { new GenericArrayData(data.asInstanceOf[Array[Any]]) } + def sortEval(array: Any, comparator: Comparator[Any]): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementType!= NullType) { + java.util.Arrays.sort(data, comparator) + } + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ca1ae535c262..05f3a8c9e5f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf @@ -309,32 +310,16 @@ case class ArrayTransform( case class ArraySort( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction with ArraySortLike with CodegenFallback { def this(argument: Expression) = this(argument, Literal(true)) @transient lazy val argumentsType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType - @transient lazy val lt: Comparator[Any] = { - val ordering = argument.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] - } + override protected def arrayExpression: Expression = argument - (o1: Any, o2: Any) => { - if (o1 == null && o2 
== null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { - ordering.compare(o1, o2) - } - } - } + override protected def nullOrder: NullOrder = NullOrder.Greatest override def dataType: ArrayType = ArrayType(argumentsType, argument.nullable) @@ -398,20 +383,7 @@ case class ArraySort( } } } - - if (function.dataType == BooleanType) { - sortEval(arr, lt) - } else { - sortEval(arr, customComparator) - } - } - - def sortEval(array: Any, comparator: Comparator[Any]): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](argumentsType) - if (argumentType!= NullType) { - java.util.Arrays.sort(data, comparator) - } - new GenericArrayData(data.asInstanceOf[Array[Any]]) + sortEval(arr, if (function.dataType == BooleanType) lt else customComparator) } override def prettyName: String = "array_sort" From 7ad574a4df899be1f66be70ef955af83b8440ac0 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Wed, 18 Sep 2019 15:55:31 +0200 Subject: [PATCH 14/27] add array_sort in functions for scala API --- .../src/main/scala/org/apache/spark/sql/functions.scala | 9 +++++++++ .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 352e18808ff3..0cc6d44da4d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3327,6 +3327,15 @@ object functions { ElementAt(column.expr, lit(value).expr) } + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr, lit(true).expr) } + /** * Remove all elements that equal to element from the given array. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d5a015ec561..df22fb306cbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -425,6 +425,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) checkAnswer( df.selectExpr("array_sort(a)", "array_sort(b)"), Seq( From d75a6715e59570e632f2392325fa3539a75919b8 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Wed, 18 Sep 2019 23:12:26 +0200 Subject: [PATCH 15/27] Refactor changes --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 2 +- .../catalyst/expressions/collectionOperations.scala | 4 ++-- .../catalyst/expressions/higherOrderFunctions.scala | 10 ++++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 15372ae207f3..d5728b902757 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -426,6 +426,7 @@ object FunctionRegistry { expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), expression[ArrayExcept]("array_except"), expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), @@ -453,7 +454,6 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), - expression[ArraySort]("array_sort"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8d9c4a520b46..6cd556b4fd5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -767,8 +767,8 @@ trait ArraySortLike extends ExpectsInputTypes { def sortEval(array: Any, comparator: Comparator[Any]): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementType!= NullType) { - java.util.Arrays.sort(data, comparator) + if (elementType != NullType) { + java.util.Arrays.sort(data, comparator) } new GenericArrayData(data.asInstanceOf[Array[Any]]) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 05f3a8c9e5f6..1b7694109841 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -301,10 +301,12 @@ case class ArrayTransform( "ascending mode", examples = """ Examples: - > SELECT _FUNC_(array(5, 6, 1), (x,y) -> f(x,y)); + > SELECT _FUNC_(array(5, 6, 1), (x, y) -> f(x, y)); [1,5,6] - > SELECT _FUNC_(array('bc', 'ab', 'dc'), (x, y) -> f(x,y)); + > SELECT _FUNC_(array('bc', 'ab', 'dc'), (x, y) -> f(x, y)); ['dc', 'bc', 'ab'] + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + | ["d", "c", "b", "a", null] """, since = "3.0.0") case class ArraySort( @@ -343,8 +345,8 @@ case class ArraySort( if (arguments.size == 2 && function.dataType == IntegerType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure("Return type of the given function has to be" + - " IntegerType") + TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " + + "IntegerType") } case failure => failure From 0c65ff89b55ed5e238e3dbc330bdbe2632184286 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Mon, 23 Sep 2019 20:20:29 +0200 Subject: [PATCH 16/27] Consistency with array_sort constructor --- .../expressions/higherOrderFunctions.scala | 46 ++++++++++--------- .../org/apache/spark/sql/functions.scala | 2 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 1b7694109841..9fb063b87d71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -291,12 +291,14 @@ case class ArrayTransform( * Sorts elements in an array using a comparator function. */ @ExpressionDescription( - usage = "_FUNC_(expr, func) - Sorts and returns the array based on the given " + + usage = "_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the " + + "input array must be orderable. Null elements will be placed at the end of the returned " + + "array. Since 3.0.0 also sorts and returns the array based on the given " + "comparator function. The comparator will take two nullable arguments " + "representing two nullable elements of the array." + - "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater" + - "than the second nullable element. Null elements will be placed at the end of the returned" + - "array. If the comparator function returns other values (including NULL)," + + "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater " + + "than the second nullable element. Null elements will be placed at the end of the returned " + + "array. If the comparator function returns other values (including NULL), " + "the query will fail and raise an error. 
By the default it will sort the array in " + "ascending mode", examples = """ @@ -308,7 +310,7 @@ case class ArrayTransform( > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); | ["d", "c", "b", "a", null] """, - since = "3.0.0") + since = "2.4.0") case class ArraySort( argument: Expression, function: Expression) @@ -366,26 +368,26 @@ case class ArraySort( (firstParam, secondParam) } - override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { - val arr = argumentValue.asInstanceOf[ArrayData] + def comparator(inputRow: InternalRow): Comparator[Any] = { val f = functionForEval - - @transient lazy val customComparator: Comparator[Any] = { - (o1: Any, o2: Any) => { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { - firstParam.value.set(o2) - secondParam.value.set(o1) - f.eval(inputRow).asInstanceOf[Int] - } + (o1: Any, o2: Any) => { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + 1 + } else if (o2 == null) { + -1 + } else { + firstParam.value.set(o2) + secondParam.value.set(o1) + f.eval(inputRow).asInstanceOf[Int] } } - sortEval(arr, if (function.dataType == BooleanType) lt else customComparator) + } + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] + sortEval(arr, if (function.dataType == BooleanType) lt else comparator(inputRow)) } override def prettyName: String = "array_sort" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0cc6d44da4d1..a65906d47741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3334,7 +3334,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_sort(e: Column): Column = withExpr { ArraySort(e.expr, lit(true).expr) } + def array_sort(e: Column): Column = withExpr { new ArraySort(e.expr) } /** * Remove all elements that equal to element from the given array. From 83ddb8bb528d8290afcc9ad6502f21b36720a5ae Mon Sep 17 00:00:00 2001 From: gschiavon Date: Tue, 24 Sep 2019 07:19:38 +0200 Subject: [PATCH 17/27] fix comment in ArraySort --- .../sql/catalyst/expressions/higherOrderFunctions.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 9fb063b87d71..21791d62fe1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -297,10 +297,8 @@ case class ArrayTransform( "comparator function. The comparator will take two nullable arguments " + "representing two nullable elements of the array." + "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater " + - "than the second nullable element. Null elements will be placed at the end of the returned " + - "array. If the comparator function returns other values (including NULL), " + - "the query will fail and raise an error. By the default it will sort the array in " + - "ascending mode", + "than the second nullable element. 
If the comparator function returns other " + + "values (including NULL), the query will fail and raise an error.", examples = """ Examples: > SELECT _FUNC_(array(5, 6, 1), (x, y) -> f(x, y)); From adf0e0e3b2fb75ae7ff67cd83a87b62c8eb105e3 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Fri, 27 Sep 2019 07:40:10 +0200 Subject: [PATCH 18/27] Indentation changes --- .../expressions/collectionOperations.scala | 2 +- .../expressions/higherOrderFunctions.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 19 ++++++++++++------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6cd556b4fd5f..2109e34a3d69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -769,7 +769,7 @@ trait ArraySortLike extends ExpectsInputTypes { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { java.util.Arrays.sort(data, comparator) - } + } new GenericArrayData(data.asInstanceOf[Array[Any]]) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 21791d62fe1e..4065ffe38e7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -304,7 +304,7 @@ case class ArrayTransform( > SELECT _FUNC_(array(5, 6, 1), (x, y) -> f(x, y)); [1,5,6] > SELECT _FUNC_(array('bc', 'ab', 'dc'), (x, y) -> f(x, y)); - ['dc', 'bc', 'ab'] + ["dc", "bc", "ab"] > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); | ["d", "c", "b", "a", null] """, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index df22fb306cbe..6cc4e6c3122a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -342,38 +342,43 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"), Seq( - Row(Seq(5, 3, 2, 2, 1)))) + Row(Seq(5, 3, 2, 2, 1))) + ) checkAnswer( df1.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))"), Seq( - Row(Seq(1, 2, 2, 3, 5)))) + Row(Seq(1, 2, 2, 3, 5))) + ) val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") checkAnswer( df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( - Row(Seq("dc", "bc", "ab")))) + Row(Seq("dc", "bc", "ab"))) + ) val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") checkAnswer( df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"), Seq( - Row(Seq("a", "abc", "abcd")))) + Row(Seq("a", "abc", "abcd"))) + ) val df4 = Seq((Array[Array[Int]](null, Array(2, 3, 1), null, Array(4, 2, 1, 4), Array(1, 2)), "x")).toDF("a", "b") checkAnswer( df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), Seq( - Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null, null)))) + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null, 
null))) + ) val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") checkAnswer( df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( - Row(Seq("dc", "bc", "ab", null)))) - + Row(Seq("dc", "bc", "ab", null))) + ) } test("sort_array/array_sort functions") { From 2224354f5bdf10c6f4ccda7603b5c495b96e9e28 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Sat, 28 Sep 2019 19:03:51 +0200 Subject: [PATCH 19/27] change dataType from ArraySort --- .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 4065ffe38e7b..0edd7ac569d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -323,7 +323,7 @@ case class ArraySort( override protected def nullOrder: NullOrder = NullOrder.Greatest - override def dataType: ArrayType = ArrayType(argumentsType, argument.nullable) + override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType] override def checkInputDataTypes(): TypeCheckResult = { From 92586ccdf5c0addbfd1cbe47eb2b79982d1d546c Mon Sep 17 00:00:00 2001 From: gschiavon Date: Tue, 22 Oct 2019 08:52:11 +0200 Subject: [PATCH 20/27] Unregistered Udfs, null handle in comparators --- .../expressions/higherOrderFunctions.scala | 10 +-------- .../spark/sql/DataFrameFunctionsSuite.scala | 22 ++++++++++++++----- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 0edd7ac569d1..fc28136c31cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -306,7 +306,7 @@ case class ArrayTransform( > SELECT _FUNC_(array('bc', 'ab', 'dc'), (x, y) -> f(x, y)); ["dc", "bc", "ab"] > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - | ["d", "c", "b", "a", null] + ["d", "c", "b", "a", null] """, since = "2.4.0") case class ArraySort( @@ -369,17 +369,9 @@ case class ArraySort( def comparator(inputRow: InternalRow): Comparator[Any] = { val f = functionForEval (o1: Any, o2: Any) => { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { firstParam.value.set(o2) secondParam.value.set(o1) f.eval(inputRow).asInstanceOf[Int] - } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6cc4e6c3122a..4d2f9a2d2261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -327,15 +327,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }) spark.udf.register("fString", (x: String, y: String) => { - if (x < y) -1 + if (x == null && y == null) 0 + else if (x == null) -1 + else if (y == null) 1 + else if (x < y) -1 else if (x == y) 0 else 1 }) spark.udf.register("fStringLength", (x: String, y: String) => { - 
if (x.length < y.length) 1 + if (x == null && y == null) 0 + else if (x == null) -1 + else if (y == null) 1 + else if (x.length < y.length) 1 else if (x.length == y.length) 0 else -1 + }) val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") @@ -365,12 +372,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Seq("a", "abc", "abcd"))) ) - val df4 = Seq((Array[Array[Int]](null, Array(2, 3, 1), null, Array(4, 2, 1, 4), + val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")).toDF("a", "b") checkAnswer( df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"), Seq( - Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4), null, null))) + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4)))) ) val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") @@ -379,6 +386,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq( Row(Seq("dc", "bc", "ab", null))) ) + + spark.sql("drop temporary function fAsc") + spark.sql("drop temporary function fDesc") + spark.sql("drop temporary function fString") + spark.sql("drop temporary function fStringLength") } test("sort_array/array_sort functions") { @@ -387,7 +399,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (Array.empty[Int], Array.empty[String]), (null, null) ).toDF("a", "b") - checkAnswer( + checkAnswer( df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), From c836dae6d9af290ab1d64b4222e23f2d3ff611b9 Mon Sep 17 00:00:00 2001 From: gschiavon Date: Fri, 25 Oct 2019 11:50:27 +0200 Subject: [PATCH 21/27] Remove ArraySortLike from array_sort --- .../expressions/collectionOperations.scala | 8 ----- .../expressions/higherOrderFunctions.scala | 34 ++++++++++++++++--- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2109e34a3d69..69629bc39616 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -765,14 +765,6 @@ trait ArraySortLike extends ExpectsInputTypes { new GenericArrayData(data.asInstanceOf[Array[Any]]) } - def sortEval(array: Any, comparator: Comparator[Any]): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementType != NullType) { - java.util.Arrays.sort(data, comparator) - } - new GenericArrayData(data.asInstanceOf[Array[Any]]) - } - def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index fc28136c31cd..ed0d121a7435 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -312,17 +312,13 @@ case class ArrayTransform( case class ArraySort( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction 
with ArraySortLike with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { def this(argument: Expression) = this(argument, Literal(true)) @transient lazy val argumentsType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType - override protected def arrayExpression: Expression = argument - - override protected def nullOrder: NullOrder = NullOrder.Greatest - override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType] override def checkInputDataTypes(): TypeCheckResult = { @@ -380,6 +376,34 @@ case class ArraySort( sortEval(arr, if (function.dataType == BooleanType) lt else comparator(inputRow)) } + def sortEval(array: Any, comparator: Comparator[Any]): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](argumentsType) + if (argumentsType != NullType) { + java.util.Arrays.sort(data, comparator) + } + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + + @transient lazy val lt: Comparator[Any] = { + val ordering = argument.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + } + + (o1: Any, o2: Any) => { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + 1 + } else if (o2 == null) { + -1 + } else { + ordering.compare(o1, o2) + } + } + } + override def prettyName: String = "array_sort" } From f7a93c56dcc8a74dbb10ab7f4548d8c775ec7870 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 4 Nov 2019 15:14:57 -0800 Subject: [PATCH 22/27] Fix ArraySort to support comparator function. --- .../expressions/collectionOperations.scala | 4 +- .../expressions/higherOrderFunctions.scala | 112 ++++++++---------- .../HigherOrderFunctionsSuite.scala | 50 ++++++++ 3 files changed, 99 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 69629bc39616..d2c0cf5b264f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -712,7 +712,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -732,7 +732,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ed0d121a7435..eb70458bc566 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf @@ -314,39 +313,33 @@ case class ArraySort( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - def this(argument: Expression) = this(argument, Literal(true)) + def this(argument: Expression) = this(argument, ArraySort.defaultComparator) - @transient lazy val argumentsType: DataType = + @transient lazy val elementType: DataType = argument.dataType.asInstanceOf[ArrayType].elementType override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType] override def checkInputDataTypes(): TypeCheckResult = { - - if (function.dataType == BooleanType) { - argument.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - TypeCheckResult.TypeCheckSuccess - case ArrayType(dt, _) => - val dtSimple = dt.catalogString - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type $dtSimple which is not orderable") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } - } else { - checkArgumentDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => - val LambdaFunction(_, arguments, _) = function - if (arguments.size == 2 && function.dataType == IntegerType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " + - "IntegerType") - } - - case failure => failure - } + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + argument.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + if (function.dataType == IntegerType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " + + "IntegerType") + } + case ArrayType(dt, _) => + val dtSimple = dt.catalogString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not " + + "orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + case failure => failure } } @@ -356,58 +349,47 @@ case class ArraySort( f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) } - @transient lazy val (firstParam, secondParam) = { - val LambdaFunction(_, (firstParam: NamedLambdaVariable) +: tail, _) = function - val secondParam = tail.head.asInstanceOf[NamedLambdaVariable] - (firstParam, secondParam) - } + @transient lazy val LambdaFunction(_, + Seq(firstElemVar: NamedLambdaVariable, secondElemVar: NamedLambdaVariable), _) = function def comparator(inputRow: InternalRow): Comparator[Any] = { val f = functionForEval (o1: Any, o2: Any) => { - firstParam.value.set(o2) - secondParam.value.set(o1) - f.eval(inputRow).asInstanceOf[Int] + firstElemVar.value.set(o1) + secondElemVar.value.set(o2) + f.eval(inputRow).asInstanceOf[Int] } } override def nullSafeEval(inputRow: InternalRow, 
argumentValue: Any): Any = { - val arr = argumentValue.asInstanceOf[ArrayData] - sortEval(arr, if (function.dataType == BooleanType) lt else comparator(inputRow)) - } - - def sortEval(array: Any, comparator: Comparator[Any]): Any = { - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](argumentsType) - if (argumentsType != NullType) { - java.util.Arrays.sort(data, comparator) + val arr = argumentValue.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementType != NullType) { + java.util.Arrays.sort(arr, comparator(inputRow)) } - new GenericArrayData(data.asInstanceOf[Array[Any]]) + new GenericArrayData(arr.asInstanceOf[Array[Any]]) } - @transient lazy val lt: Comparator[Any] = { - val ordering = argument.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] - } + override def prettyName: String = "array_sort" +} - (o1: Any, o2: Any) => { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { - ordering.compare(o1, o2) - } - } - } +object ArraySort { - override def prettyName: String = "array_sort" + def comparator(left: Expression, right: Expression): Expression = { + val lit0 = Literal(0) + val lit1 = Literal(1) + val litm1 = Literal(-1) -} + If(And(IsNull(left), IsNull(right)), lit0, + If(IsNull(left), lit1, If(IsNull(right), litm1, + If(LessThan(left, right), litm1, If(GreaterThan(left, right), lit1, lit0))))) + } + val defaultComparator: LambdaFunction = { + val left = UnresolvedNamedLambdaVariable(Seq("left")) + val right = UnresolvedNamedLambdaVariable(Seq("right")) + LambdaFunction(comparator(left, right), Seq(left, right)) + } +} /** * Filters entries in a map using the provided function. 
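The ArraySort companion introduced above is worth a second look: `defaultComparator` is an expression-tree encoding of a null-last natural ordering, wired in so that the one-argument form keeps its earlier behaviour while the two-argument form takes over sorting. A plain-Scala sketch of the semantics that expression tree encodes (illustrative only; `ord` stands in for the element type's ordering and this helper is not part of the patch):

    // Nulls compare greater than any non-null value, so they sort to the end;
    // otherwise the element type's natural ordering decides. A sign-equivalent
    // result is all that java.util.Arrays.sort needs from the comparator.
    def nullLastCompare(left: Any, right: Any, ord: Ordering[Any]): Int = {
      if (left == null && right == null) 0
      else if (left == null) 1
      else if (right == null) -1
      else ord.compare(left, right)
    }
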
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index b83d03025d21..9a613cfe61d0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -84,6 +84,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
   }
 
+  def arraySort(expr: Expression): Expression = {
+    arraySort(expr, ArraySort.comparator)
+  }
+
+  def arraySort(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
+    val ArrayType(et, cn) = expr.dataType
+    ArraySort(expr, createLambda(et, cn, et, cn, f)).bind(validateBinding)
+  }
+
   def filter(expr: Expression, f: Expression => Expression): Expression = {
     val ArrayType(et, cn) = expr.dataType
     ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
@@ -162,6 +171,47 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq("[1, 3, 5]", null, "[4, 6]"))
   }
 
+  test("ArraySort") {
+    val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
+    val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+    val d1 = new Decimal().set(10)
+    val d2 = new Decimal().set(100)
+    val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0)))
+    val a5 = Literal.create(Seq(null, null), ArrayType(NullType))
+
+    val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+    val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
+
+    val typeAA = ArrayType(ArrayType(IntegerType))
+    val aa1 = Array[java.lang.Integer](1, 2)
+    val aa2 = Array[java.lang.Integer](3, null, 4)
+    val arrayArray = Literal.create(Seq(aa2, aa1), typeAA)
+
+    val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil)))
+    val aas1 = Array(create_row(1))
+    val aas2 = Array(create_row(2))
+    val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS)
+
+    checkEvaluation(arraySort(a0), Seq(1, 2, 3))
+    checkEvaluation(arraySort(a1), Seq[Integer]())
+    checkEvaluation(arraySort(a2), Seq("a", "b"))
+    checkEvaluation(arraySort(a3), Seq("a", "b", null))
+    checkEvaluation(arraySort(a4), Seq(d1, d2))
+    checkEvaluation(arraySort(a5), Seq(null, null))
+    checkEvaluation(arraySort(arrayStruct), Seq(create_row(1), create_row(2)))
+    checkEvaluation(arraySort(arrayArray), Seq(aa1, aa2))
+    checkEvaluation(arraySort(arrayArrayStruct), Seq(aas1, aas2))
+
+    checkEvaluation(arraySort(a0, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
+      Seq(3, 2, 1))
+    checkEvaluation(arraySort(a3, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
+      Seq(null, "b", "a"))
+    checkEvaluation(arraySort(a4, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
+      Seq(d2, d1))
+  }
+
   test("MapFilter") {
     def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
       val MapType(kt, vt, vcn) = expr.dataType

From a3451420b51e1dcd38d122ba4f527ee422ffc080 Mon Sep 17 00:00:00 2001
From: gschiavon
Date: Thu, 7 Nov 2019 14:19:51 +0100
Subject: [PATCH 23/27] fixed test after changing comparator order

---
 .../spark/sql/DataFrameFunctionsSuite.scala   | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 4d2f9a2d2261..8fcbd776c7be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -314,13 +314,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
   test("array_sort with lambda functions") {
 
-    spark.udf.register("fAsc", (x: Int, y: Int) => {
+    spark.udf.register("fDesc", (x: Int, y: Int) => {
       if (x < y) -1
       else if (x == y) 0
       else 1
     })
 
-    spark.udf.register("fDesc", (x: Int, y: Int) => {
+    spark.udf.register("fAsc", (x: Int, y: Int) => {
       if (x < y) 1
       else if (x == y) 0
       else -1
@@ -328,20 +328,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     spark.udf.register("fString", (x: String, y: String) => {
       if (x == null && y == null) 0
-      else if (x == null) -1
-      else if (y == null) 1
-      else if (x < y) -1
+      else if (x == null) 1
+      else if (y == null) -1
+      else if (x < y) 1
       else if (x == y) 0
-      else 1
+      else -1
     })
 
     spark.udf.register("fStringLength", (x: String, y: String) => {
       if (x == null && y == null) 0
-      else if (x == null) -1
-      else if (y == null) 1
-      else if (x.length < y.length) 1
+      else if (x == null) 1
+      else if (y == null) -1
+      else if (x.length < y.length) -1
       else if (x.length == y.length) 0
-      else -1
+      else 1
     })

From 3fe059a8cdeb19974467fe07df125723e06d91d9 Mon Sep 17 00:00:00 2001
From: gschiavon
Date: Fri, 8 Nov 2019 08:15:45 +0100
Subject: [PATCH 24/27] fixing ArraySort Query Example

---
 .../expressions/higherOrderFunctions.scala    | 28 ++++++++++---------
 .../HigherOrderFunctionsSuite.scala           |  2 ++
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index eb70458bc566..a59e30e0f660 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -289,25 +289,28 @@ case class ArrayTransform(
 /**
  * Sorts elements in an array using a comparator function.
  */
+// scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the " +
-    "input array must be orderable. Null elements will be placed at the end of the returned " +
-    "array. Since 3.0.0 also sorts and returns the array based on the given " +
-    "comparator function. The comparator will take two nullable arguments " +
-    "representing two nullable elements of the array." +
-    "It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater " +
-    "than the second nullable element. If the comparator function returns other " +
-    "values (including NULL), the query will fail and raise an error.",
+  usage = """_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the
+    input array must be orderable. Null elements will be placed at the end of the returned
+    array. Since 3.0.0 also sorts and returns the array based on the given
+    comparator function. The comparator will take two nullable arguments
+    representing two nullable elements of the array.
+    It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater
+    than the second nullable element. If the comparator function returns other
+    values (including NULL), the query will fail and raise an error.
+  """,
   examples = """
     Examples:
-      > SELECT _FUNC_(array(5, 6, 1), (x, y) -> f(x, y));
+      > SELECT _FUNC_(array(5, 6, 1), (left, right) -> If(And(IsNull(left), IsNull(right)), 0, If(IsNull(left), 1, If(IsNull(right), -1, If(left < right, -1, If(left > right, 1, 0))))));
       [1,5,6]
-      > SELECT _FUNC_(array('bc', 'ab', 'dc'), (x, y) -> f(x, y));
-      ["dc", "bc", "ab"]
+      > SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> If(And(IsNull(left), IsNull(right)), 0, If(IsNull(left), 1, If(IsNull(right), -1, If(left < right, -1, If(left > right, 1, 0))))));
+      ["dc","bc","ab"]
       > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
-      ["d", "c", "b", "a", null]
+      ["a","b","c","d",null]
   """,
   since = "2.4.0")
+// scalastyle:on line.size.limit
 case class ArraySort(
     argument: Expression,
     function: Expression)
@@ -319,7 +322,6 @@ case class ArraySort(
     argument.dataType.asInstanceOf[ArrayType].elementType
 
   override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType]
-
   override def checkInputDataTypes(): TypeCheckResult = {
     checkArgumentDataTypes() match {
       case TypeCheckResult.TypeCheckSuccess =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index 9a613cfe61d0..30d74bf3ec11 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -210,6 +210,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(null, "b", "a"))
     checkEvaluation(arraySort(a4, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
       Seq(d2, d1))
+
+
   }
 
   test("MapFilter") {

From 53b563ec365bdbcc8a95a87a5885808684a7a2b3 Mon Sep 17 00:00:00 2001
From: gschiavon
Date: Wed, 13 Nov 2019 09:05:30 +0100
Subject: [PATCH 25/27] Fix lambda function names and readability of query examples

---
 .../catalyst/expressions/higherOrderFunctions.scala  | 12 ++++++------
 .../expressions/HigherOrderFunctionsSuite.scala      |  2 --
 .../apache/spark/sql/DataFrameFunctionsSuite.scala   | 11 +++++------
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a59e30e0f660..471d3407d749 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -294,17 +294,17 @@ case class ArrayTransform(
   usage = """_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the
     input array must be orderable. Null elements will be placed at the end of the returned
     array. Since 3.0.0 also sorts and returns the array based on the given
-    comparator function. The comparator will take two nullable arguments
-    representing two nullable elements of the array.
-    It returns -1, 0, or 1 as the first nullable element is less than, equal to, or greater
-    than the second nullable element. If the comparator function returns other
+    comparator function. The comparator will take two arguments
+    representing two elements of the array.
+    It returns -1, 0, or 1 as the first element is less than, equal to, or greater
+    than the second element. If the comparator function returns other
     values (including NULL), the query will fail and raise an error.
   """,
   examples = """
     Examples:
-      > SELECT _FUNC_(array(5, 6, 1), (left, right) -> If(And(IsNull(left), IsNull(right)), 0, If(IsNull(left), 1, If(IsNull(right), -1, If(left < right, -1, If(left > right, 1, 0))))));
+      > SELECT _FUNC_(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end);
       [1,5,6]
-      > SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> If(And(IsNull(left), IsNull(right)), 0, If(IsNull(left), 1, If(IsNull(right), -1, If(left < right, -1, If(left > right, 1, 0))))));
+      > SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> case when left is null and right is null then 0 when left is null then -1 when right is null then 1 when left < right then 1 when left > right then -1 else 0 end);
       ["dc","bc","ab"]
       > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
       ["a","b","c","d",null]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index 30d74bf3ec11..9a613cfe61d0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -210,8 +210,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(null, "b", "a"))
     checkEvaluation(arraySort(a4, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
       Seq(d2, d1))
-
-
   }
 
   test("MapFilter") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8fcbd776c7be..0b8a0ba94f04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -314,13 +314,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
   test("array_sort with lambda functions") {
 
-    spark.udf.register("fDesc", (x: Int, y: Int) => {
+    spark.udf.register("fAsc", (x: Int, y: Int) => {
       if (x < y) -1
       else if (x == y) 0
       else 1
     })
 
-    spark.udf.register("fAsc", (x: Int, y: Int) => {
+    spark.udf.register("fDesc", (x: Int, y: Int) => {
       if (x < y) 1
       else if (x == y) 0
       else -1
@@ -342,20 +342,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       else if (x.length < y.length) -1
       else if (x.length == y.length) 0
       else 1
-
     })
 
     val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a")
 
     checkAnswer(
       df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"),
       Seq(
-        Row(Seq(5, 3, 2, 2, 1)))
+        Row(Seq(1, 2, 2, 3, 5)))
     )
 
     checkAnswer(
       df1.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))"),
       Seq(
-        Row(Seq(1, 2, 2, 3, 5)))
+        Row(Seq(5, 3, 2, 2, 1)))
     )
 
     val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a")
@@ -375,7 +374,7 @@
     val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2,
       1, 4), Array(1, 2)), "x")).toDF("a", "b")
     checkAnswer(
-      df4.selectExpr("array_sort(a, (x, y) -> fDesc(cardinality(x), cardinality(y)))"),
+      df4.selectExpr("array_sort(a, (x, y) -> fAsc(cardinality(x), cardinality(y)))"),
       Seq(
         Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4))))
     )

From 7c38b8a11ae22f32603084fe0e0992e4303a944f Mon Sep 17 00:00:00 2001
From: gschiavon
Date: Thu, 14 Nov 2019 04:57:01 +0100
Subject: [PATCH 26/27] revert indent

---
 .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0b8a0ba94f04..8de0a834b92f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -461,7 +461,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null)))
     )
 
-      assert(intercept[AnalysisException] {
+    assert(intercept[AnalysisException] {
       df3.selectExpr("array_sort(a)").collect()
     }.getMessage().contains("argument 1 requires array type, however, '`a`' is of string type"))
   }

From ef28d4fd63d6fa71081035b45293a90fc1575687 Mon Sep 17 00:00:00 2001
From: gschiavon
Date: Sun, 17 Nov 2019 20:07:11 +0100
Subject: [PATCH 27/27] Changed ArraySort description

---
 .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 471d3407d749..c46bc97a184e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -293,12 +293,12 @@ case class ArrayTransform(
 @ExpressionDescription(
   usage = """_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the
     input array must be orderable. Null elements will be placed at the end of the returned
-    array. Since 3.0.0 also sorts and returns the array based on the given
+    array. Since 3.0.0 this function also sorts and returns the array based on the given
     comparator function. The comparator will take two arguments
    representing two elements of the array.
     It returns -1, 0, or 1 as the first element is less than, equal to, or greater
     than the second element. If the comparator function returns other
-    values (including NULL), the query will fail and raise an error.
+    values (including null), the function will fail and raise an error.
   """,
   examples = """
     Examples:
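
Taken together, the series leaves array_sort with a nulls-last natural ordering by default and an optional user-supplied comparator lambda. A minimal end-to-end sketch of the resulting API, modeled on the suites patched above (illustrative only: it assumes a Spark 3.0.0 build that includes this series, and the fDesc UDF mirrors the one registered in DataFrameFunctionsSuite):

  import org.apache.spark.sql.SparkSession

  object ArraySortDemo {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local[*]").appName("array_sort").getOrCreate()
      import spark.implicits._

      // Descending comparator, same shape as the suite's fDesc UDF.
      spark.udf.register("fDesc", (x: Int, y: Int) => {
        if (x < y) 1
        else if (x == y) 0
        else -1
      })

      val df = Seq(Array(3, 2, 5, 1, 2)).toDF("a")

      // Default comparator: ascending, nulls last.
      df.selectExpr("array_sort(a)").show()                        // [1, 2, 2, 3, 5]

      // Comparator supplied as a lambda that calls the registered UDF.
      df.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))").show() // [5, 3, 2, 2, 1]

      // Inline comparator, in the style of the fixed ExpressionDescription examples.
      df.selectExpr(
        "array_sort(a, (x, y) -> case when x < y then -1 when x > y then 1 else 0 end)").show()

      spark.stop()
    }
  }
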