diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 844ca7a8e99c..e03a2d37b9b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1019,6 +1019,21 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def desc: Column = withExpr { SortOrder(expr, Descending) }
 
+  /**
+   * Returns a descending ordering used in sorting, where null values appear first.
+   * {{{
+   *   // Scala: sort a DataFrame by age column in descending order with NULLS FIRST.
+   *   df.sort(df("age").desc_nulls_first)
+   *
+   *   // Java
+   *   df.sort(df.col("age").desc_nulls_first());
+   * }}}
+   *
+   * @group expr_ops
+   * @since 2.1.0
+   */
+  def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }
+
   /**
    * Returns an ordering used in sorting.
    * {{{
@@ -1034,6 +1049,22 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def asc: Column = withExpr { SortOrder(expr, Ascending) }
 
+  /**
+   * Returns an ascending ordering used in sorting, where null values appear last.
+   * {{{
+   *   // Scala: sort a DataFrame by age column in ascending order with NULLS LAST.
+   *   df.sort(df("age").asc_nulls_last)
+   *
+   *   // Java
+   *   df.sort(df.col("age").asc_nulls_last());
+   * }}}
+   *
+   * @group expr_ops
+   * @since 2.1.0
+   */
+  def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }
+
+
   /**
    * Prints the expression to the console for debugging purpose.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 18e736ab6986..cdb5b6f89a34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -118,6 +118,19 @@ object functions {
    */
   def asc(columnName: String): Column = Column(columnName).asc
 
+  /**
+   * Returns a sort expression based on the ascending order of the column with NULLS LAST.
+   * {{{
+   *   // Sort by dept in ascending order with nulls last, and then age in descending order.
+   *   df.sort(asc_nulls_last("dept"), desc("age"))
+   * }}}
+   *
+   * @group sort_funcs
+   * @since 2.1.0
+   */
+  def asc_nulls_last(columnName: String): Column = Column(columnName).asc_nulls_last
+
+
   /**
    * Returns a sort expression based on the descending order of the column.
    * {{{
@@ -130,6 +143,19 @@ object functions {
    */
   def desc(columnName: String): Column = Column(columnName).desc
 
+  /**
+   * Returns a sort expression based on the descending order of the column with NULLS FIRST.
+   * {{{
+   *   // Sort by dept in ascending order, and then age in descending order with nulls first.
+   *   df.sort(asc("dept"), desc_nulls_first("age"))
+   * }}}
+   *
+   * @group sort_funcs
+   * @since 2.1.0
+   */
+  def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first
+
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c2d256bdd335..f3fcbfe383c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -325,6 +325,40 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(6))
   }
 
+  test("sorting with null ordering") {
+    checkAnswer(
+      nullableData.orderBy('a.asc_nulls_last, 'b.desc_nulls_first),
+      Seq(
+        Row(2, null), Row(2, "B"), Row(3, null), Row(4, "a"),
+        Row(5, "A"), Row(null, "c"), Row(null, "b")
+      )
+    )
+
+    checkAnswer(
+      nullableData.orderBy(asc_nulls_last("a"), desc_nulls_first("b")),
+      Seq(
+        Row(2, null), Row(2, "B"), Row(3, null), Row(4, "a"),
+        Row(5, "A"), Row(null, "c"), Row(null, "b")
+      )
+    )
+
+    checkAnswer(
+      nullableData.orderBy('a.desc_nulls_first, 'b.asc_nulls_last),
+      Seq(
+        Row(null, "b"), Row(null, "c"), Row(5, "A"), Row(4, "a"),
+        Row(3, null), Row(2, "B"), Row(2, null)
+      )
+    )
+
+    checkAnswer(
+      nullableData.orderBy(desc_nulls_first("a"), asc_nulls_last("b")),
+      Seq(
+        Row(null, "b"), Row(null, "c"), Row(5, "A"), Row(4, "a"),
+        Row(3, null), Row(2, "B"), Row(2, null)
+      )
+    )
+  }
+
   test("global sorting") {
     checkAnswer(
       testData2.orderBy('a.asc, 'b.asc),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 0cfe260e5215..ee31d5a739ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -169,6 +169,20 @@ private[sql] trait SQLTestData { self =>
     rdd
   }
 
+  protected lazy val nullableData: DataFrame = {
+    val df = spark.sparkContext.parallelize(
+      NullableRecord(4, "a") ::
+      NullableRecord(null, "c") ::
+      NullableRecord(2, null) ::
+      NullableRecord(null, "b") ::
+      NullableRecord(3, null) ::
+      NullableRecord(5, "A") ::
+      NullableRecord(2, "B") :: Nil, 2
+    ).toDF("a", "b")
+    df.createOrReplaceTempView("nullableData")
+    df
+  }
+
   protected lazy val nullInts: DataFrame = {
     val df = spark.sparkContext.parallelize(
       NullInts(1) ::
@@ -305,6 +319,7 @@ private[sql] object SQLTestData {
   case class IntField(i: Int)
   case class NullInts(a: Integer)
   case class NullStrings(n: Int, s: String)
+  case class NullableRecord(n: Integer, s: String)
   case class TableName(tableName: String)
   case class Person(id: Int, name: String, age: Int)
   case class Salary(personId: Int, salary: Double)