diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index f2d06e793f9dd..b567ac302b840 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -44,7 +44,7 @@
  * @since 3.0.0
  */
 @Unstable
-public final class CalendarInterval implements Serializable {
+public final class CalendarInterval implements Serializable, Comparable<CalendarInterval> {
   // NOTE: If you're moving or renaming this file, you should also update Unidoc configuration
   // specified in 'SparkBuild.scala'.
   public final int months;
@@ -127,4 +127,26 @@ private void appendUnit(StringBuilder sb, long value, String unit) {
    * @throws ArithmeticException if a numeric overflow occurs
    */
   public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); }
+
+  /**
+   * This method does not define a meaningful ordering for CalendarInterval instances; intervals
+   * are not orderable and cannot be used in an ORDER BY clause.
+   * Instead, it is used to find identical interval instances for aggregation purposes.
+   * It compares the 'months', 'days', and 'microseconds' fields of this CalendarInterval with
+   * those of another instance: first on 'months', then on 'days', and finally on 'microseconds'.
+   *
+   * @param o The CalendarInterval instance to compare with.
+   * @return Zero if and only if this interval is field-wise equal to the specified interval,
+   *         and a non-zero value whose sign follows the first differing field otherwise
+   */
+  @Override
+  public int compareTo(CalendarInterval o) {
+    if (this.months != o.months) {
+      return Integer.compare(this.months, o.months);
+    } else if (this.days != o.days) {
+      return Integer.compare(this.days, o.days);
+    } else {
+      return Long.compare(this.microseconds, o.microseconds);
+    }
+  }
 }
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
index b8b7105233656..0a1ee279316f1 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
@@ -76,6 +76,22 @@ public void toStringTest() {
       i.toString());
   }
 
+  @Test
+  public void compareToTest() {
+    CalendarInterval i = new CalendarInterval(0, 0, 0);
+
+    assertEquals(0, i.compareTo(new CalendarInterval(0, 0, 0)));
+    assertEquals(-1, i.compareTo(new CalendarInterval(0, 0, 1)));
+    assertEquals(-1, i.compareTo(new CalendarInterval(0, 1, 0)));
+    assertEquals(-1, i.compareTo(new CalendarInterval(0, 1, -1)));
+    assertEquals(-1, i.compareTo(new CalendarInterval(1, 0, 0)));
+    assertEquals(-1, i.compareTo(new CalendarInterval(1, 0, -1)));
+    assertEquals(1, i.compareTo(new CalendarInterval(0, 0, -1)));
+    assertEquals(1, i.compareTo(new CalendarInterval(0, -1, 0)));
+    assertEquals(1, i.compareTo(new CalendarInterval(-1, 0, 0)));
+    assertEquals(1, i.compareTo(new CalendarInterval(-1, 0, 1)));
+  }
+
   @Test
   public void periodAndDurationTest() {
     CalendarInterval interval = new CalendarInterval(120, -40, 123456);
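Note: compareTo above is purely field-wise (months, then days, then microseconds), so it is consistent with equals() and hashCode() but deliberately says nothing about wall-clock length. A minimal sketch of the resulting semantics (illustrative only, not part of the patch):

    import org.apache.spark.unsafe.types.CalendarInterval

    object CompareToSemantics {
      def main(args: Array[String]): Unit = {
        val oneMonth = new CalendarInterval(1, 0, 0)
        val thirtyDays = new CalendarInterval(0, 30, 0)

        // Non-zero: months differ (1 vs 0), so these land in different
        // GROUP BY buckets even though they may denote the same span of time.
        assert(oneMonth.compareTo(thirtyDays) > 0)

        // Zero exactly when all three fields match, mirroring equals(),
        // which is the property aggregation relies on.
        assert(oneMonth.compareTo(new CalendarInterval(1, 0, 0)) == 0)
      }
    }

This is why grouping treats a one-month interval and a thirty-day interval as distinct keys: months are compared before days ever enter the picture.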
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 41071d031d2e0..2bbe730d4cfb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -194,7 +194,7 @@ object ExprUtils extends QueryErrorsBase {
     }
 
-    // Check if the data type of expr is orderable.
-    if (!RowOrdering.isOrderable(expr.dataType)) {
+    // Check if the data type of expr is groupable.
+    if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
       expr.failAnalysis(
         errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
         messageParameters = Map(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d10e4a1ced1bd..c8c2d5558b148 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -627,6 +627,7 @@ class CodegenContext extends Logging {
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
     case struct: StructType => genComp(struct, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
+    case CalendarIntervalType => s"$c1.equals($c2)"
     case NullType => "false"
     case _ =>
       throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError(
@@ -652,6 +653,7 @@ class CodegenContext extends Logging {
     // use c1 - c2 may overflow
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)"
+    case CalendarIntervalType => s"$c1.compareTo($c2)"
     case NullType => "0"
     case array: ArrayType =>
       val elementType = array.elementType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 1972aeb382658..278c1fc3f73b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -77,9 +77,11 @@ object AggUtils {
       child: SparkPlan): SparkPlan = {
     val useHash = Aggregate.supportsHashAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+
+    val forceObjHashAggregate = forceApplyObjectHashAggregate(child.conf)
     val forceSortAggregate = forceApplySortAggregate(child.conf)
 
-    if (useHash && !forceSortAggregate) {
+    if (useHash && !forceSortAggregate && !forceObjHashAggregate) {
       HashAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
         isStreaming = isStreaming,
@@ -94,7 +96,7 @@ object AggUtils {
       val objectHashEnabled = child.conf.useObjectHashAggregation
       val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions)
 
-      if (objectHashEnabled && useObjectHash && !forceSortAggregate) {
+      if (forceObjHashAggregate || (objectHashEnabled && useObjectHash && !forceSortAggregate)) {
        ObjectHashAggregateExec(
          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
          isStreaming = isStreaming,
@@ -589,4 +591,13 @@ object AggUtils {
     Utils.isTesting &&
       conf.getConfString("spark.sql.test.forceApplySortAggregate", "false") == "true"
   }
+
+  /**
+   * Returns whether an object hash aggregate should be forcibly applied.
+   * The config key is hard-coded because it is for testing only and should not be exposed.
+   */
+  private def forceApplyObjectHashAggregate(conf: SQLConf): Boolean = {
+    Utils.isTesting &&
+      conf.getConfString("spark.sql.test.forceApplyObjectHashAggregate", "false") == "true"
+  }
 }
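Note: spark.sql.test.forceApplyObjectHashAggregate mirrors the existing spark.sql.test.forceApplySortAggregate escape hatch; it lets tests pin planning to ObjectHashAggregateExec regardless of the usual heuristics. A hypothetical usage sketch, assuming a test suite that provides withSQLConf (e.g. via SQLTestUtils) and a SparkSession named spark; make_interval produces a CalendarIntervalType column:

    withSQLConf("spark.sql.test.forceApplyObjectHashAggregate" -> "true") {
      // Group by a calendar interval while the object-hash path is forced.
      val df = spark.sql("SELECT make_interval(0, 1) AS i").groupBy("i").count()
      // The physical plan should now contain ObjectHashAggregate rather than
      // HashAggregate or SortAggregate.
      assert(df.queryExecution.executedPlan.toString.contains("ObjectHashAggregate"))
      df.collect()
    }

Forcing each physical operator in turn is how the new DataFrameAggregateSuite test below covers all three aggregation paths.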
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index c33820ed85e53..8a88ad0a57e3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -174,6 +174,7 @@ abstract class HashMapGenerator(
           """
       }
       case StringType => hashBytes(s"$input.getBytes()")
+      case CalendarIntervalType => hashInt(s"$input.hashCode()")
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 3691d76d25122..0ab8926c016ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.test.SQLTestData.DecimalData
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
+import org.apache.spark.unsafe.types.CalendarInterval
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
 
@@ -2125,6 +2126,37 @@ class DataFrameAggregateSuite extends QueryTest
       Seq(Row(1))
     )
   }
+
+  test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
+    val numRows = 50
+    val configurations = Seq(
+      Seq.empty[(String, String)], // hash aggregate is used by default
+      Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
+        "spark.sql.TungstenAggregate.testFallbackStartsAt" -> "1, 10"),
+      Seq("spark.sql.test.forceApplyObjectHashAggregate" -> "true"),
+      Seq(
+        "spark.sql.test.forceApplyObjectHashAggregate" -> "true",
+        SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"),
+      Seq("spark.sql.test.forceApplySortAggregate" -> "true")
+    )
+
+    def createAggregate(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*"))
+
+    val dfSame = (0 until numRows)
+      .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
+      .toDF("c0")
+
+    val dfDifferent = (0 until numRows)
+      .map(i => Tuple1(new CalendarInterval(i, i, i)))
+      .toDF("c0")
+
+    for (conf <- configurations) {
+      withSQLConf(conf: _*) {
+        assert(createAggregate(dfSame).count() == 1)
+        assert(createAggregate(dfDifferent).count() == numRows)
+      }
+    }
+  }
 }
 
 case class B(c: Option[Double])
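Taken together, the changes let CalendarIntervalType flow through grouping analysis (ExprUtils), generated equality/comparison and hashing code (CodeGenerator, HashMapGenerator), and all three aggregate operators (AggUtils). An end-to-end sketch of what the patch enables, illustrative only and assuming a local SparkSession:

    import org.apache.spark.sql.SparkSession

    object GroupByIntervalExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        // Before this patch, grouping by a CalendarIntervalType column failed
        // analysis with GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE.
        spark.sql(
          """SELECT i, count(*) AS cnt
            |FROM (SELECT make_interval(0, CAST(id % 3 AS INT)) AS i FROM range(10))
            |GROUP BY i""".stripMargin).show()
        // Expected: three groups (0, 1, and 2 months).
        spark.stop()
      }
    }

Note that ORDER BY on calendar intervals remains unsupported; as the new compareTo javadoc states, the comparison only serves to identify identical intervals for aggregation.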