diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5314821ea3a5..d2c0cf5b264f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -900,54 +900,6 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } - -/** - * Sorts the input array in ascending order according to the natural ordering of - * the array elements and returns it. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must - be orderable. Null elements will be placed at the end of the returned array. - """, - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - ["a","b","c","d",null] - """, - since = "2.4.0") -// scalastyle:on line.size.limit -case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { - - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - - override def arrayExpression: Expression = child - override def nullOrder: NullOrder = NullOrder.Greatest - - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - TypeCheckResult.TypeCheckSuccess - case ArrayType(dt, _) => - val dtSimple = dt.catalogString - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type $dtSimple which is not orderable") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } - - override def nullSafeEval(array: Any): Any = { - sortEval(array, true) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) - } - - override def prettyName: String = "array_sort" -} - /** * Returns a random permutation of the given array. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ed26bb375de2..c46bc97a184e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Comparator import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable @@ -285,6 +286,113 @@ case class ArrayTransform( override def prettyName: String = "transform" } +/** + * Sorts elements in an array using a comparator function. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the + input array must be orderable. Null elements will be placed at the end of the returned + array. Since 3.0.0 this function also sorts and returns the array based on the given + comparator function. The comparator will take two arguments + representing two elements of the array. + It returns -1, 0, or 1 as the first element is less than, equal to, or greater + than the second element. If the comparator function returns other + values (including null), the function will fail and raise an error. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end); + [1,5,6] + > SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> case when left is null and right is null then 0 when left is null then -1 when right is null then 1 when left < right then 1 when left > right then -1 else 0 end); + ["dc","bc","ab"] + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort( + argument: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + def this(argument: Expression) = this(argument, ArraySort.defaultComparator) + + @transient lazy val elementType: DataType = + argument.dataType.asInstanceOf[ArrayType].elementType + + override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType] + override def checkInputDataTypes(): TypeCheckResult = { + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + argument.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + if (function.dataType == IntegerType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " + + "IntegerType") + } + case ArrayType(dt, _) => + val dtSimple = dt.catalogString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not " + + "orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + case failure => failure + } + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySort = { + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = + f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil)) + } + + @transient lazy val LambdaFunction(_, + Seq(firstElemVar: NamedLambdaVariable, secondElemVar: NamedLambdaVariable), _) = function + + def comparator(inputRow: InternalRow): Comparator[Any] = { + val f = functionForEval + (o1: Any, o2: Any) => { + firstElemVar.value.set(o1) + secondElemVar.value.set(o2) + f.eval(inputRow).asInstanceOf[Int] + } + } + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementType != NullType) { + java.util.Arrays.sort(arr, comparator(inputRow)) + } + new GenericArrayData(arr.asInstanceOf[Array[Any]]) + } + + override def prettyName: String = "array_sort" +} + +object ArraySort { + + def comparator(left: Expression, right: Expression): Expression = { + val lit0 = Literal(0) + val lit1 = Literal(1) + val litm1 = Literal(-1) + + If(And(IsNull(left), IsNull(right)), lit0, + If(IsNull(left), lit1, If(IsNull(right), litm1, + If(LessThan(left, right), litm1, If(GreaterThan(left, right), lit1, lit0))))) + } + + val defaultComparator: LambdaFunction = { + val left = UnresolvedNamedLambdaVariable(Seq("left")) + val right = UnresolvedNamedLambdaVariable(Seq("right")) + LambdaFunction(comparator(left, right), Seq(left, right)) + } +} + /** * Filters entries in a map using the provided function. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 603073b40d7a..79e7904ec588 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -363,16 +363,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) - - checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) - checkEvaluation(ArraySort(a1), Seq[Integer]()) - checkEvaluation(ArraySort(a2), Seq("a", "b")) - checkEvaluation(ArraySort(a3), Seq("a", "b", null)) - checkEvaluation(ArraySort(a4), Seq(d1, d2)) - checkEvaluation(ArraySort(a5), Seq(null, null)) - checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) - checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) - checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index b83d03025d21..9a613cfe61d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -84,6 +84,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) } + def arraySort(expr: Expression): Expression = { + arraySort(expr, ArraySort.comparator) + } + + def arraySort(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArraySort(expr, createLambda(et, cn, et, cn, f)).bind(validateBinding) + } + def filter(expr: Expression, f: Expression => Expression): Expression = { val ArrayType(et, cn) = expr.dataType ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) @@ -162,6 +171,47 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq("[1, 3, 5]", null, "[4, 6]")) } + test("ArraySort") { + val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) + + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(arraySort(a0), Seq(1, 2, 3)) + checkEvaluation(arraySort(a1), Seq[Integer]()) + checkEvaluation(arraySort(a2), Seq("a", "b")) + checkEvaluation(arraySort(a3), Seq("a", "b", null)) + checkEvaluation(arraySort(a4), Seq(d1, d2)) + checkEvaluation(arraySort(a5), Seq(null, null)) + checkEvaluation(arraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(arraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(arraySort(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(arraySort(a0, (left, right) => UnaryMinus(ArraySort.comparator(left, right))), + Seq(3, 2, 1)) + checkEvaluation(arraySort(a3, (left, right) => UnaryMinus(ArraySort.comparator(left, right))), + Seq(null, "b", "a")) + checkEvaluation(arraySort(a4, (left, right) => UnaryMinus(ArraySort.comparator(left, right))), + Seq(d2, d1)) + } + test("MapFilter") { def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { val MapType(kt, vt, vcn) = expr.dataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 84e0eaff2d42..a65906d47741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3334,7 +3334,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + def array_sort(e: Column): Column = withExpr { new ArraySort(e.expr) } /** * Remove all elements that equal to element from the given array. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7d044638db57..8de0a834b92f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -312,6 +312,86 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("array_sort with lambda functions") { + + spark.udf.register("fAsc", (x: Int, y: Int) => { + if (x < y) -1 + else if (x == y) 0 + else 1 + }) + + spark.udf.register("fDesc", (x: Int, y: Int) => { + if (x < y) 1 + else if (x == y) 0 + else -1 + }) + + spark.udf.register("fString", (x: String, y: String) => { + if (x == null && y == null) 0 + else if (x == null) 1 + else if (y == null) -1 + else if (x < y) 1 + else if (x == y) 0 + else -1 + }) + + spark.udf.register("fStringLength", (x: String, y: String) => { + if (x == null && y == null) 0 + else if (x == null) 1 + else if (y == null) -1 + else if (x.length < y.length) -1 + else if (x.length == y.length) 0 + else 1 + }) + + val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") + checkAnswer( + df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"), + Seq( + Row(Seq(1, 2, 2, 3, 5))) + ) + + checkAnswer( + df1.selectExpr("array_sort(a, (x, y) -> fDesc(x, y))"), + Seq( + Row(Seq(5, 3, 2, 2, 1))) + ) + + val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") + checkAnswer( + df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), + Seq( + Row(Seq("dc", "bc", "ab"))) + ) + + val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") + checkAnswer( + df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"), + Seq( + Row(Seq("a", "abc", "abcd"))) + ) + + val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), + Array(1, 2)), "x")).toDF("a", "b") + checkAnswer( + df4.selectExpr("array_sort(a, (x, y) -> fAsc(cardinality(x), cardinality(y)))"), + Seq( + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4)))) + ) + + val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") + checkAnswer( + df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), + Seq( + Row(Seq("dc", "bc", "ab", null))) + ) + + spark.sql("drop temporary function fAsc") + spark.sql("drop temporary function fDesc") + spark.sql("drop temporary function fString") + spark.sql("drop temporary function fStringLength") + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), @@ -383,7 +463,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(intercept[AnalysisException] { df3.selectExpr("array_sort(a)").collect() - }.getMessage().contains("only supports array input")) + }.getMessage().contains("argument 1 requires array type, however, '`a`' is of string type")) } def testSizeOfArray(sizeOfNull: Any): Unit = {