Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
/**
* The internal representation of interval type.
*/
public final class CalendarInterval implements Serializable, Comparable<CalendarInterval> {
public final class CalendarInterval implements Serializable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a MAX_VALUE and MIN_VALUE in this file, shall we remove them as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, remove them too

public final int months;
public final int days;
public final long microseconds;
Expand Down Expand Up @@ -59,29 +59,6 @@ public int hashCode() {
return Objects.hash(months, days, microseconds);
}

@Override
public int compareTo(CalendarInterval that) {
long thisAdjustDays =
this.microseconds / MICROS_PER_DAY + this.days + this.months * DAYS_PER_MONTH;
long thatAdjustDays =
that.microseconds / MICROS_PER_DAY + that.days + that.months * DAYS_PER_MONTH;
long daysDiff = thisAdjustDays - thatAdjustDays;
if (daysDiff == 0) {
long msDiff = (this.microseconds % MICROS_PER_DAY) - (that.microseconds % MICROS_PER_DAY);
if (msDiff == 0) {
return 0;
} else if (msDiff > 0) {
return 1;
} else {
return -1;
}
} else if (daysDiff > 0){
return 1;
} else {
return -1;
}
}

@Override
public String toString() {
if (months == 0 && days == 0 && microseconds == 0) {
Expand Down Expand Up @@ -133,16 +110,4 @@ private void appendUnit(StringBuilder sb, long value, String unit) {
* @throws ArithmeticException if a numeric overflow occurs
*/
public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); }

/**
* A constant holding the minimum value an {@code CalendarInterval} can have.
*/
public static CalendarInterval MIN_VALUE =
new CalendarInterval(Integer.MIN_VALUE, Integer.MIN_VALUE, Long.MIN_VALUE);

/**
* A constant holding the maximum value an {@code CalendarInterval} can have.
*/
public static CalendarInterval MAX_VALUE =
new CalendarInterval(Integer.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE);
}
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ class CodegenContext extends Logging {
s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case CalendarIntervalType => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
Expand All @@ -630,7 +629,6 @@ class CodegenContext extends Logging {
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case CalendarIntervalType => s"$c1.compareTo($c2)"
case NullType => "0"
case array: ArrayType =>
val elementType = array.elementType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder],
def isOrderable(dataType: DataType): Boolean = dataType match {
case NullType => true
case dt: AtomicType => true
case CalendarIntervalType => true
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
case array: ArrayType => isOrderable(array.elementType)
case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ object TypeUtils {
def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case c: CalendarIntervalType => c.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
private[sql] object TypeCollection {

/**
* Types that include numeric types and interval type, which support numeric type calculations,
* i.e. unary_minus, unary_positive, sum, avg, min, max, add and subtract operations.
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
*/
val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.types

import org.apache.spark.annotation.Stable
import org.apache.spark.unsafe.types.CalendarInterval

/**
* The data type representing calendar intervals. The calendar interval is stored internally in
Expand All @@ -40,8 +39,6 @@ class CalendarIntervalType private() extends DataType {

override def simpleString: String = "interval"

val ordering: Ordering[CalendarInterval] = Ordering[CalendarInterval]

private[spark] override def asNullable: CalendarIntervalType = this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,38 +426,33 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}

test("interval overflow check") {
intercept[ArithmeticException](negateExact(new CalendarInterval(Int.MinValue, 0, 0)))
assert(negate(new CalendarInterval(Int.MinValue, 0, 0)) ===
new CalendarInterval(Int.MinValue, 0, 0))
intercept[ArithmeticException](negateExact(CalendarInterval.MIN_VALUE))
assert(negate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE)
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 0, 1)))
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 1, 0)))
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(1, 0, 0)))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) ===
new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) ===
new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) ===
new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue))

intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 0, -1)))
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, -1, 0)))
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(-1, 0, 0)))
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) ===
new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue))
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) ===
new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue))
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) ===
new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue))

intercept[ArithmeticException](multiplyExact(CalendarInterval.MAX_VALUE, 2))
intercept[ArithmeticException](divideExact(CalendarInterval.MAX_VALUE, 0.5))
val maxMonth = new CalendarInterval(Int.MaxValue, 0, 0)
val minMonth = new CalendarInterval(Int.MinValue, 0, 0)
val oneMonth = new CalendarInterval(1, 0, 0)
val maxDay = new CalendarInterval(0, Int.MaxValue, 0)
val minDay = new CalendarInterval(0, Int.MinValue, 0)
val oneDay = new CalendarInterval(0, 1, 0)
val maxMicros = new CalendarInterval(0, 0, Long.MaxValue)
val minMicros = new CalendarInterval(0, 0, Long.MinValue)
val oneMicros = new CalendarInterval(0, 0, 1)
intercept[ArithmeticException](negateExact(minMonth))
assert(negate(minMonth) === minMonth)

intercept[ArithmeticException](addExact(maxMonth, oneMonth))
intercept[ArithmeticException](addExact(maxDay, oneDay))
intercept[ArithmeticException](addExact(maxMicros, oneMicros))
assert(add(maxMonth, oneMonth) === minMonth)
assert(add(maxDay, oneDay) === minDay)
assert(add(maxMicros, oneMicros) === minMicros)

intercept[ArithmeticException](subtractExact(minDay, oneDay))
intercept[ArithmeticException](subtractExact(minMonth, oneMonth))
intercept[ArithmeticException](subtractExact(minMicros, oneMicros))
assert(subtract(minMonth, oneMonth) === maxMonth)
assert(subtract(minDay, oneDay) === maxDay)
assert(subtract(minMicros, oneMicros) === maxMicros)

intercept[ArithmeticException](multiplyExact(maxMonth, 2))
intercept[ArithmeticException](divideExact(maxDay, 0.5))
}
}
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ class Dataset[T] private[sql](
}
}

private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
queryExecution.analyzed.output.filter { attr =>
TypeCollection.NumericAndInterval.acceptsType(attr.dataType)
private[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructType, TypeCollection}
import org.apache.spark.sql.types.{NumericType, StructType}

/**
* A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
Expand Down Expand Up @@ -88,20 +88,20 @@ class RelationalGroupedDataset protected[sql](
case expr: Expression => Alias(expr, toPrettySQL(expr))()
}

private[this] def aggregateNumericOrIntervalColumns(
colNames: String*)(f: Expression => AggregateFunction): DataFrame = {
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {

val columnExprs = if (colNames.isEmpty) {
// No columns specified. Use all numeric calculation supported columns.
df.numericCalculationSupportedColumns
// No columns specified. Use all numeric columns.
df.numericColumns
} else {
// Make sure all specified columns are numeric calculation supported columns.
// Make sure all specified columns are numeric.
colNames.map { colName =>
val namedExpr = df.resolve(colName)
if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {
if (!namedExpr.dataType.isInstanceOf[NumericType]) {
throw new AnalysisException(
s""""$colName" is not a numeric or calendar interval column. """ +
"Aggregation function can only be applied on a numeric or calendar interval column.")
s""""$colName" is not a numeric column. """ +
"Aggregation function can only be applied on a numeric column.")
}
namedExpr
}
Expand Down Expand Up @@ -269,64 +269,63 @@ class RelationalGroupedDataset protected[sql](
def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))

/**
* Compute the average value for each numeric or calender interval columns for each group. This
* is an alias for `avg`.
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the average values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
def mean(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Average)
aggregateNumericColumns(colNames : _*)(Average)
}

/**
* Compute the max value for each numeric calender interval columns for each group.
* Compute the max value for each numeric columns for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the max values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
def max(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Max)
aggregateNumericColumns(colNames : _*)(Max)
}

/**
* Compute the mean value for each numeric calender interval columns for each group.
* Compute the mean value for each numeric columns for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the mean values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
def avg(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Average)
aggregateNumericColumns(colNames : _*)(Average)
}

/**
* Compute the min value for each numeric calender interval column for each group.
* Compute the min value for each numeric column for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the min values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
def min(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Min)
aggregateNumericColumns(colNames : _*)(Min)
}

/**
* Compute the sum for each numeric calender interval columns for each group.
* Compute the sum for each numeric columns for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the sum for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
def sum(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Sum)
aggregateNumericColumns(colNames : _*)(Sum)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,8 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
}

private[columnar] final class IntervalColumnStats extends ColumnStats {
protected var upper: CalendarInterval = CalendarInterval.MIN_VALUE
protected var lower: CalendarInterval = CalendarInterval.MAX_VALUE

override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getInterval(ordinal)
if (value.compareTo(upper) > 0) upper = value
if (value.compareTo(lower) < 0) lower = value
sizeInBytes += CALENDAR_INTERVAL.actualSize(row, ordinal)
count += 1
} else {
Expand All @@ -312,7 +306,7 @@ private[columnar] final class IntervalColumnStats extends ColumnStats {
}

override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
Array[Any](null, null, nullCount, count, sizeInBytes)
}

private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
Expand Down
42 changes: 0 additions & 42 deletions sql/core/src/test/resources/sql-tests/inputs/interval.sql
Original file line number Diff line number Diff line change
@@ -1,47 +1,5 @@
-- test for intervals

-- greater than or equal
select interval '1 day' > interval '23 hour';
select interval '-1 day' >= interval '-23 hour';
select interval '-1 day' > null;
select null > interval '-1 day';

-- less than or equal
select interval '1 minutes' < interval '1 hour';
select interval '-1 day' <= interval '-23 hour';

-- equal
select interval '1 year' = interval '360 days';
select interval '1 year 2 month' = interval '420 days';
select interval '1 year' = interval '365 days';
select interval '1 month' = interval '30 days';
select interval '1 minutes' = interval '1 hour';
select interval '1 minutes' = null;
select null = interval '-1 day';

-- null safe equal
select interval '1 minutes' <=> null;
select null <=> interval '1 minutes';

-- complex interval representation
select INTERVAL '9 years 1 months -1 weeks -4 days -10 hours -46 minutes' > interval '1 minutes';

-- ordering
select cast(v as interval) i from VALUES ('1 seconds'), ('4 seconds'), ('3 seconds') t(v) order by i;

-- unlimited days
select interval '1 month 120 days' > interval '2 month';
select interval '1 month 30 days' = interval '2 month';

-- unlimited microseconds
select interval '1 month 29 days 40 hours' > interval '2 month';

-- max
select max(cast(v as interval)) from VALUES ('1 seconds'), ('4 seconds'), ('3 seconds') t(v);

-- min
select min(cast(v as interval)) from VALUES ('1 seconds'), ('4 seconds'), ('3 seconds') t(v);

-- multiply and divide an interval by a number
select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15');
select interval 4 month 2 weeks 3 microseconds * 1.5;
Expand Down
Loading