Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
* @since 3.0.0
*/
@Unstable
public final class CalendarInterval implements Serializable {
public final class CalendarInterval implements Serializable, Comparable<CalendarInterval> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Please check the behavior
#27262
I'm not sure. @yaooqinn

// NOTE: If you're moving or renaming this file, you should also update Unidoc configuration
// specified in 'SparkBuild.scala'.
public final int months;
Expand Down Expand Up @@ -127,4 +127,26 @@ private void appendUnit(StringBuilder sb, long value, String unit) {
* @throws ArithmeticException if a numeric overflow occurs
*/
public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); }

/**
* This method is not used to order CalendarInterval instances, as they are not orderable and
* cannot be used in a ORDER BY statement.
* Instead, it is used to find identical interval instances for aggregation purposes.
* It compares the 'months', 'days', and 'microseconds' fields of this CalendarInterval
* with another instance. The comparison is done first on the 'months', then on the 'days',
* and finally on the 'microseconds'.
*
* @param o The CalendarInterval instance to compare with.
* @return Zero if this object is equal to the specified object, and non-zero otherwise
*/
@Override
public int compareTo(CalendarInterval o) {
if (this.months != o.months) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing intervals does not necessarily short circuits via months. We could result in 1 month > 0 months 32 days, which is wrong, obviously.

Besides, 1 month can be 28 ~ 30 days, making the legacy calendar interval type uncomparable

Copy link
Contributor

@cloud-fan cloud-fan Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some comments to explain that this is alphabet ordering. It does not have actual meaning but just makes it possible to find identical interval instances.

We should do the same thing for map type so that we can group by map values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stefankandic did you generate this using IDEA?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the comments.

@cloud-fan method was generated by intellij but I implemented the logic

return Integer.compare(this.months, o.months);
} else if (this.days != o.days) {
return Integer.compare(this.days, o.days);
} else {
return Long.compare(this.microseconds, o.microseconds);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ public void toStringTest() {
i.toString());
}

@Test
public void compareToTest() {
  // The ordering is alphabetical on (months, days, microseconds): months are
  // compared first, then days, then microseconds; it has no calendar semantics.
  CalendarInterval i = new CalendarInterval(0, 0, 0);

  // The compareTo contract only guarantees the sign of the result, so assert
  // on Integer.signum rather than on Integer.compare returning exactly -1/0/1.
  // JUnit's assertEquals takes the expected value first, then the actual one.
  assertEquals(0, Integer.signum(i.compareTo(new CalendarInterval(0, 0, 0))));
  assertEquals(-1, Integer.signum(i.compareTo(new CalendarInterval(0, 0, 1))));
  assertEquals(-1, Integer.signum(i.compareTo(new CalendarInterval(0, 1, 0))));
  assertEquals(-1, Integer.signum(i.compareTo(new CalendarInterval(0, 1, -1))));
  assertEquals(-1, Integer.signum(i.compareTo(new CalendarInterval(1, 0, 0))));
  assertEquals(-1, Integer.signum(i.compareTo(new CalendarInterval(1, 0, -1))));
  assertEquals(1, Integer.signum(i.compareTo(new CalendarInterval(0, 0, -1))));
  assertEquals(1, Integer.signum(i.compareTo(new CalendarInterval(0, -1, 0))));
  assertEquals(1, Integer.signum(i.compareTo(new CalendarInterval(-1, 0, 0))));
  assertEquals(1, Integer.signum(i.compareTo(new CalendarInterval(-1, 0, 1))));
}

@Test
public void periodAndDurationTest() {
CalendarInterval interval = new CalendarInterval(120, -40, 123456);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ object ExprUtils extends QueryErrorsBase {
}

// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
expr.failAnalysis(
errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
messageParameters = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ class CodegenContext extends Logging {
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case CalendarIntervalType => s"$c1.equals($c2)"
case NullType => "false"
case _ =>
throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError(
Expand All @@ -652,6 +653,7 @@ class CodegenContext extends Logging {
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)"
case CalendarIntervalType => s"$c1.compareTo($c2)"
case NullType => "0"
case array: ArrayType =>
val elementType = array.elementType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ object AggUtils {
child: SparkPlan): SparkPlan = {
val useHash = Aggregate.supportsHashAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))

val forceObjHashAggregate = forceApplyObjectHashAggregate(child.conf)
val forceSortAggregate = forceApplySortAggregate(child.conf)

if (useHash && !forceSortAggregate) {
if (useHash && !forceSortAggregate && !forceObjHashAggregate) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
isStreaming = isStreaming,
Expand All @@ -94,7 +96,7 @@ object AggUtils {
val objectHashEnabled = child.conf.useObjectHashAggregation
val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions)

if (objectHashEnabled && useObjectHash && !forceSortAggregate) {
if (forceObjHashAggregate || (objectHashEnabled && useObjectHash && !forceSortAggregate)) {
ObjectHashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
isStreaming = isStreaming,
Expand Down Expand Up @@ -589,4 +591,13 @@ object AggUtils {
Utils.isTesting &&
conf.getConfString("spark.sql.test.forceApplySortAggregate", "false") == "true"
}

/**
 * Returns whether an object hash aggregate should be force applied.
 * The config key is hard-coded because it's testing only and should not be exposed.
 */
private def forceApplyObjectHashAggregate(conf: SQLConf): Boolean =
  Utils.isTesting &&
    conf.getConfString("spark.sql.test.forceApplyObjectHashAggregate", "false") == "true"
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ abstract class HashMapGenerator(
"""
}
case StringType => hashBytes(s"$input.getBytes()")
case CalendarIntervalType => hashInt(s"$input.hashCode()")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import org.apache.spark.sql.test.SQLTestData.DecimalData
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.CalendarInterval

case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)

Expand Down Expand Up @@ -2125,6 +2126,37 @@ class DataFrameAggregateSuite extends QueryTest
Seq(Row(1))
)
}

test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
  val numRows = 50

  // Exercise every aggregate execution path: hash aggregate (the default),
  // hash aggregate falling back to sort-based (codegen must be disabled to
  // reach the fallback logic), object hash aggregate, object hash aggregate
  // with sort-based fallback, and plain sort aggregate.
  val configurations = Seq(
    Seq.empty[(String, String)], // hash aggregate is used by default
    Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
      "spark.sql.TungstenAggregate.testFallbackStartsAt" -> "1, 10"),
    Seq("spark.sql.test.forceApplyObjectHashAggregate" -> "true"),
    Seq(
      "spark.sql.test.forceApplyObjectHashAggregate" -> "true",
      SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"),
    Seq("spark.sql.test.forceApplySortAggregate" -> "true")
  )

  // Group by the interval column and count rows per group.
  def aggregated(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*"))

  // All rows hold the same interval: grouping must collapse them to one group.
  val identicalIntervals = (0 until numRows)
    .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
    .toDF("c0")

  // Every row holds a distinct interval: grouping must keep them all apart.
  val distinctIntervals = (0 until numRows)
    .map(i => Tuple1(new CalendarInterval(i, i, i)))
    .toDF("c0")

  configurations.foreach { conf =>
    withSQLConf(conf: _*) {
      assert(aggregated(identicalIntervals).count() == 1)
      assert(aggregated(distinctIntervals).count() == numRows)
    }
  }
}
}

case class B(c: Option[Double])
Expand Down