From bca199ffa5f8a8972fc8b2ad0d065c3b83dd4614 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 2 Aug 2024 08:59:35 +0800 Subject: [PATCH] Remove ArraySortLike trait --- .../expressions/collectionOperations.scala | 196 ++++++++---------- 1 file changed, 89 insertions(+), 107 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7e9bf989e9cf..516f521bc964 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,7 +27,6 @@ import org.apache.spark.SparkException.internalError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} @@ -1025,15 +1024,73 @@ case class MapSort(base: Expression) } /** - * Common base class for [[SortArray]] and [[ArraySort]]. + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. */ -trait ArraySortLike extends ExpectsInputTypes { - protected def arrayExpression: Expression +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. NaN is greater than any non-NaN + elements for double/float type. Null elements will be placed at the beginning of the returned + array in ascending order or at the end of the returned array in descending order. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), false); + ["d","c","b","a",null] + """, + group = "array_funcs", + since = "1.5.0") +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ExpectsInputTypes with NullIntolerant with QueryErrorsBase { - protected def nullOrder: NullOrder + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(1), + "requiredType" -> toSQLType(BooleanType), + "inputSql" -> toSQLExpr(ascendingOrder), + "inputType" -> toSQLType(ascendingOrder.dataType)) + ) + } + case ArrayType(dt, _) => + DataTypeMismatch( + errorSubClass = "INVALID_ORDERING_TYPE", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> toSQLType(base.dataType) + ) + ) + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(0), + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(base), + "inputType" -> toSQLType(base.dataType)) + ) + } @transient private lazy val lt: Comparator[Any] = { - val ordering = arrayExpression.dataType match { + val ordering = base.dataType match { case _ @ ArrayType(n, _) => PhysicalDataType.ordering(n) } @@ -1041,9 +1098,9 @@ trait ArraySortLike extends ExpectsInputTypes { if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - nullOrder + -1 } else if (o2 == null) { - -nullOrder + 1 } else { ordering.compare(o1, o2) } @@ -1051,7 +1108,7 @@ trait ArraySortLike extends ExpectsInputTypes { } @transient private lazy val gt: Comparator[Any] = { - val ordering = arrayExpression.dataType match { + val ordering = base.dataType match { case _ @ ArrayType(n, _) => PhysicalDataType.ordering(n) } @@ -1060,9 +1117,9 @@ trait ArraySortLike extends ExpectsInputTypes { if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - -nullOrder + 1 } else if (o2 == null) { - nullOrder + -1 } else { ordering.compare(o2, o1) } @@ -1070,12 +1127,12 @@ trait ArraySortLike extends ExpectsInputTypes { } @transient lazy val elementType: DataType = - arrayExpression.dataType.asInstanceOf[ArrayType].elementType + base.dataType.asInstanceOf[ArrayType].elementType private def resultArrayElementNullable: Boolean = - arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + base.dataType.asInstanceOf[ArrayType].containsNull - def sortEval(array: Any, ascending: Boolean): Any = { + private def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { java.util.Arrays.sort(data, if (ascending) lt else gt) @@ -1083,7 +1140,11 @@ trait ArraySortLike extends ExpectsInputTypes { new GenericArrayData(data.asInstanceOf[Array[Any]]) } - def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + private def sortCodegen( + ctx: CodegenContext, + ev: ExprCode, + base: String, + order: String): String = { val genericArrayData = classOf[GenericArrayData].getName val unsafeArrayData = classOf[UnsafeArrayData].getName val array = ctx.freshName("array") @@ -1111,18 +1172,18 @@ trait ArraySortLike extends ExpectsInputTypes { val canPerformFastSort = CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !resultArrayElementNullable val nonNullPrimitiveAscendingSort = if (canPerformFastSort) { - val javaType = CodeGenerator.javaType(elementType) - val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |if ($order) { - | $javaType[] $array = $base.to${primitiveTypeName}Array(); - | java.util.Arrays.sort($array); - | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); - |} else + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else """.stripMargin - } else { - "" - } + } else { + "" + } s""" |$nonNullPrimitiveAscendingSort |{ @@ -1133,9 +1194,9 @@ trait ArraySortLike extends ExpectsInputTypes { | if ($o1 == null && $o2 == null) { | return 0; | } else if ($o1 == null) { - | return $sortOrder * $nullOrder; + | return -$sortOrder; | } else if ($o2 == null) { - | return -$sortOrder * $nullOrder; + | return $sortOrder; | } | $comp | return $sortOrder * $c; @@ -1147,85 +1208,6 @@ trait ArraySortLike extends ExpectsInputTypes { } } -} - -object ArraySortLike { - type NullOrder = Int - // Least: place null element at the first of the array for ascending order - // Greatest: place null element at the end of the array for ascending order - object NullOrder { - val Least: NullOrder = -1 - val Greatest: NullOrder = 1 - } -} - -/** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order - according to the natural ordering of the array elements. NaN is greater than any non-NaN - elements for double/float type. Null elements will be placed at the beginning of the returned - array in ascending order or at the end of the returned array in descending order. - """, - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); - [null,"a","b","c","d"] - """, - group = "array_funcs", - since = "1.5.0") -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ArraySortLike with NullIntolerant with QueryErrorsBase { - - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def arrayExpression: Expression = base - override def nullOrder: NullOrder = NullOrder.Least - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(1), - "requiredType" -> toSQLType(BooleanType), - "inputSql" -> toSQLExpr(ascendingOrder), - "inputType" -> toSQLType(ascendingOrder.dataType)) - ) - } - case ArrayType(dt, _) => - DataTypeMismatch( - errorSubClass = "INVALID_ORDERING_TYPE", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(base.dataType) - ) - ) - case _ => - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(0), - "requiredType" -> toSQLType(ArrayType), - "inputSql" -> toSQLExpr(base), - "inputType" -> toSQLType(base.dataType)) - ) - } - override def nullSafeEval(array: Any, ascending: Any): Any = { sortEval(array, ascending.asInstanceOf[Boolean]) }