Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
Expand Down Expand Up @@ -1025,33 +1024,91 @@ case class MapSort(base: Expression)
}

/**
* Common base class for [[SortArray]] and [[ArraySort]].
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
trait ArraySortLike extends ExpectsInputTypes {
protected def arrayExpression: Expression
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
according to the natural ordering of the array elements. NaN is greater than any non-NaN
elements for double/float type. Null elements will be placed at the beginning of the returned
array in ascending order or at the end of the returned array in descending order.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true);
[null,"a","b","c","d"]
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), false);
["d","c","b","a",null]
""",
group = "array_funcs",
since = "1.5.0")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with NullIntolerant with QueryErrorsBase {

protected def nullOrder: NullOrder
// Single-argument form: the order flag defaults to the literal `true`.
def this(e: Expression) = this(e, Literal(true))

// Children of this binary expression: the array comes first, the order flag second.
override def left: Expression = base
override def right: Expression = ascendingOrder
// The result type is exactly the type of the input array expression.
override def dataType: DataType = base.dataType
// Expected argument types, listed in the same order as the children.
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

/**
 * Validates the argument types of this expression.
 *
 * Succeeds only when `base` is an array whose element type is orderable and `ascendingOrder`
 * is a boolean literal. Otherwise a [[DataTypeMismatch]] is returned:
 *  - `UNEXPECTED_INPUT_TYPE` (with `ordinalNumber(1)`) when the order flag is not a boolean
 *    literal,
 *  - `INVALID_ORDERING_TYPE` when `base` is an array whose element type is not orderable,
 *  - `UNEXPECTED_INPUT_TYPE` (with `ordinalNumber(0)`) when `base` is not an array.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  // Both "wrong argument type" errors share one shape; only the position, the required
  // type and the offending argument differ.
  def unexpectedInput(
      index: Int,
      required: AbstractDataType,
      arg: Expression): TypeCheckResult = {
    DataTypeMismatch(
      errorSubClass = "UNEXPECTED_INPUT_TYPE",
      messageParameters = Map(
        "paramIndex" -> ordinalNumber(index),
        "requiredType" -> toSQLType(required),
        "inputSql" -> toSQLExpr(arg),
        "inputType" -> toSQLType(arg.dataType)))
  }

  base.dataType match {
    case ArrayType(elemType, _) if RowOrdering.isOrderable(elemType) =>
      ascendingOrder match {
        case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess
        case _ => unexpectedInput(1, BooleanType, ascendingOrder)
      }
    // Reached only by arrays that failed the orderable guard above.
    case _: ArrayType =>
      DataTypeMismatch(
        errorSubClass = "INVALID_ORDERING_TYPE",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> toSQLType(base.dataType)))
    case _ =>
      unexpectedInput(0, ArrayType, base)
  }
}

@transient private lazy val lt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
val ordering = base.dataType match {
case _ @ ArrayType(n, _) =>
PhysicalDataType.ordering(n)
}
(o1: Any, o2: Any) => {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
nullOrder
-1
} else if (o2 == null) {
-nullOrder
1
} else {
ordering.compare(o1, o2)
}
}
}

@transient private lazy val gt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
val ordering = base.dataType match {
case _ @ ArrayType(n, _) =>
PhysicalDataType.ordering(n)
}
Expand All @@ -1060,30 +1117,34 @@ trait ArraySortLike extends ExpectsInputTypes {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-nullOrder
1
} else if (o2 == null) {
nullOrder
-1
} else {
ordering.compare(o2, o1)
}
}
}

@transient lazy val elementType: DataType =
arrayExpression.dataType.asInstanceOf[ArrayType].elementType
base.dataType.asInstanceOf[ArrayType].elementType

private def resultArrayElementNullable: Boolean =
arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
base.dataType.asInstanceOf[ArrayType].containsNull

def sortEval(array: Any, ascending: Boolean): Any = {
private def sortEval(array: Any, ascending: Boolean): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementType != NullType) {
java.util.Arrays.sort(data, if (ascending) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]])
}

def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
private def sortCodegen(
ctx: CodegenContext,
ev: ExprCode,
base: String,
order: String): String = {
val genericArrayData = classOf[GenericArrayData].getName
val unsafeArrayData = classOf[UnsafeArrayData].getName
val array = ctx.freshName("array")
Expand Down Expand Up @@ -1111,18 +1172,18 @@ trait ArraySortLike extends ExpectsInputTypes {
val canPerformFastSort = CodeGenerator.isPrimitiveType(elementType) &&
elementType != BooleanType && !resultArrayElementNullable
val nonNullPrimitiveAscendingSort = if (canPerformFastSort) {
val javaType = CodeGenerator.javaType(elementType)
val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|if ($order) {
| $javaType[] $array = $base.to${primitiveTypeName}Array();
| java.util.Arrays.sort($array);
| ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array);
|} else
val javaType = CodeGenerator.javaType(elementType)
val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|if ($order) {
| $javaType[] $array = $base.to${primitiveTypeName}Array();
| java.util.Arrays.sort($array);
| ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array);
|} else
""".stripMargin
} else {
""
}
} else {
""
}
s"""
|$nonNullPrimitiveAscendingSort
|{
Expand All @@ -1133,9 +1194,9 @@ trait ArraySortLike extends ExpectsInputTypes {
| if ($o1 == null && $o2 == null) {
| return 0;
| } else if ($o1 == null) {
| return $sortOrder * $nullOrder;
| return -$sortOrder;
| } else if ($o2 == null) {
| return -$sortOrder * $nullOrder;
| return $sortOrder;
| }
| $comp
| return $sortOrder * $c;
Expand All @@ -1147,85 +1208,6 @@ trait ArraySortLike extends ExpectsInputTypes {
}
}

}

/**
 * Holds the null-placement constants used when sorting arrays.
 */
object ArraySortLike {
  /** Integer code describing where null elements are placed. */
  type NullOrder = Int

  object NullOrder {
    /** Place null elements at the end of the array for ascending order. */
    val Greatest: NullOrder = 1

    /** Place null elements at the beginning of the array for ascending order. */
    val Least: NullOrder = -Greatest
  }
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
according to the natural ordering of the array elements. NaN is greater than any non-NaN
elements for double/float type. Null elements will be placed at the beginning of the returned
array in ascending order or at the end of the returned array in descending order.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true);
[null,"a","b","c","d"]
""",
group = "array_funcs",
since = "1.5.0")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ArraySortLike with NullIntolerant with QueryErrorsBase {

// Single-argument form: the order flag defaults to the literal `true`.
def this(e: Expression) = this(e, Literal(true))

// Children of this binary expression: the array comes first, the order flag second.
override def left: Expression = base
override def right: Expression = ascendingOrder
// The result type is exactly the type of the input array expression.
override def dataType: DataType = base.dataType
// Expected argument types, listed in the same order as the children.
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

// Hooks required by ArraySortLike: the array being sorted is `base`, and nulls use
// NullOrder.Least (placed first in ascending order).
override def arrayExpression: Expression = base
override def nullOrder: NullOrder = NullOrder.Least

/**
 * Validates the argument types of this expression.
 *
 * Succeeds only when `base` is an array whose element type is orderable and `ascendingOrder`
 * is a boolean literal. Otherwise a [[DataTypeMismatch]] is returned:
 *  - `UNEXPECTED_INPUT_TYPE` (with `ordinalNumber(1)`) when the order flag is not a boolean
 *    literal,
 *  - `INVALID_ORDERING_TYPE` when `base` is an array whose element type is not orderable,
 *  - `UNEXPECTED_INPUT_TYPE` (with `ordinalNumber(0)`) when `base` is not an array.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  // Both "wrong argument type" errors share one shape; only the position, the required
  // type and the offending argument differ.
  def unexpectedInput(
      index: Int,
      required: AbstractDataType,
      arg: Expression): TypeCheckResult = {
    DataTypeMismatch(
      errorSubClass = "UNEXPECTED_INPUT_TYPE",
      messageParameters = Map(
        "paramIndex" -> ordinalNumber(index),
        "requiredType" -> toSQLType(required),
        "inputSql" -> toSQLExpr(arg),
        "inputType" -> toSQLType(arg.dataType)))
  }

  base.dataType match {
    case ArrayType(elemType, _) if RowOrdering.isOrderable(elemType) =>
      ascendingOrder match {
        case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess
        case _ => unexpectedInput(1, BooleanType, ascendingOrder)
      }
    // Reached only by arrays that failed the orderable guard above.
    case _: ArrayType =>
      DataTypeMismatch(
        errorSubClass = "INVALID_ORDERING_TYPE",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> toSQLType(base.dataType)))
    case _ =>
      unexpectedInput(0, ArrayType, base)
  }
}

/**
 * Evaluates this expression on already-evaluated argument values by delegating to `sortEval`.
 *
 * @param array the evaluated array argument, passed through unchanged
 * @param ascending the evaluated order flag; it is cast to `Boolean` before use
 * @return whatever `sortEval` returns for the given array and direction
 */
override def nullSafeEval(array: Any, ascending: Any): Any = {
  val isAscending = ascending.asInstanceOf[Boolean]
  sortEval(array, isAscending)
}
Expand Down