diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 533c5614885c..bd7473706ca8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3968,6 +3968,19 @@ object functions {
    */
  def array_sort(e: Column): Column = withExpr { new ArraySort(e.expr) }
 
+  /**
+   * Sorts the input array based on the given comparator function. The comparator will take two
+   * arguments representing two elements of the array. It returns a negative integer, 0, or a
+   * positive integer as the first element is less than, equal to, or greater than the second
+   * element. If the comparator function returns null, the function will fail and raise an error.
+   *
+   * @group collection_funcs
+   * @since 3.4.0
+   */
+  def array_sort(e: Column, comparator: (Column, Column) => Column): Column = withExpr {
+    new ArraySort(e.expr, createLambda(comparator))
+  }
+
   /**
    * Remove all elements that equal to element from the given array.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 9924fbfbf626..b80925f8638d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -434,6 +434,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     })
 
     val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a")
+    checkAnswer(
+      df1.select(array_sort(col("a"), (x, y) => call_udf("fAsc", x, y))),
+      Seq(
+        Row(Seq(1, 2, 2, 3, 5)))
+    )
+
+    checkAnswer(
+      df1.select(array_sort(col("a"), (x, y) => call_udf("fDesc", x, y))),
+      Seq(
+        Row(Seq(5, 3, 2, 2, 1)))
+    )
+
     checkAnswer(
       df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"),
       Seq(
@@ -447,6 +459,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
 
     val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a")
+    checkAnswer(
+      df2.select(array_sort(col("a"), (x, y) => call_udf("fString", x, y))),
+      Seq(
+        Row(Seq("dc", "bc", "ab")))
+    )
+
     checkAnswer(
       df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"),
       Seq(
@@ -454,6 +472,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
 
     val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a")
+    checkAnswer(
+      df3.select(array_sort(col("a"), (x, y) => call_udf("fStringLength", x, y))),
+      Seq(
+        Row(Seq("a", "abc", "abcd")))
+    )
+
     checkAnswer(
       df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"),
       Seq(
@@ -462,6 +486,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
     val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)),
       "x")).toDF("a", "b")
+    checkAnswer(
+      df4.select(array_sort(col("a"), (x, y) => call_udf("fAsc", size(x), size(y)))),
+      Seq(
+        Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4))))
+    )
+
     checkAnswer(
       df4.selectExpr("array_sort(a, (x, y) -> fAsc(cardinality(x), cardinality(y)))"),
       Seq(
@@ -469,6 +499,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
 
     val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a")
+    checkAnswer(
+      df5.select(array_sort(col("a"), (x, y) => call_udf("fString", x, y))),
+      Seq(
+        Row(Seq("dc", "bc", "ab", null)))
+    )
+
     checkAnswer(
       df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"),
       Seq(
@@ -484,6 +520,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
   test("SPARK-38130: array_sort with lambda of non-orderable items") {
     val df6 = Seq((Array[Map[String, Int]](Map("a" -> 1), Map("b" -> 2, "c" -> 3), Map()),
       "x")).toDF("a", "b")
+    checkAnswer(
+      df6.select(array_sort(col("a"), (x, y) => size(x) - size(y))),
+      Seq(
+        Row(Seq[Map[String, Int]](Map(), Map("a" -> 1), Map("b" -> 2, "c" -> 3))))
+    )
+
     checkAnswer(
       df6.selectExpr("array_sort(a, (x, y) -> cardinality(x) - cardinality(y))"),
      Seq(