diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 4f38f8276f1a..233bfc90434a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.Date
 import java.time.ZoneId
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
@@ -64,6 +65,7 @@ object Cast {
 
     case (StringType, DateType) => true
     case (TimestampType, DateType) => true
+    case (DoubleType, DateType) => true
 
     case (StringType, CalendarIntervalType) => true
 
@@ -492,6 +494,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   private[this] def castToDate(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s, zoneId).orNull)
+    case DoubleType =>
+      buildCast[Double](_, daysSinceEpoch => daysSinceEpoch.toInt)
     case TimestampType =>
       // throw valid precision more than seconds, according to Hive.
       // Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -718,8 +722,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       })
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1d else 0d)
-    case DateType =>
-      buildCast[Int](_, d => null)
+    case DateType => _.asInstanceOf[Int].toDouble
     case TimestampType =>
       buildCast[Long](_, t => timestampToDouble(t))
     case x: NumericType =>
@@ -1143,6 +1146,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
            $evNull = true;
          }
         """
+    case DoubleType => (c, evPrim, evNull) => code"$evPrim = (int) $c;"
     case TimestampType =>
       val zid = getZoneId()
       (c, evPrim, evNull) =>
@@ -1605,7 +1609,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     case BooleanType =>
       (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
     case DateType =>
-      (c, evPrim, evNull) => code"$evNull = true;"
+      (c, evPrim, evNull) => code"$evPrim = (double) $c;"
     case TimestampType =>
       (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
     case DecimalType() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index d3ce1f8d331a..713fe01dcb09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -40,10 +40,17 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
 
   override def children: Seq[Expression] = child :: Nil
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, DateType)
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val isNumeric = TypeUtils.checkForNumericExpr(child.dataType, "function average")
+
+    if (isNumeric.isFailure && child.dataType == DateType) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      isNumeric
+    }
+  }
 
   override def nullable: Boolean = true
 
@@ -53,6 +60,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
+    case DateType => DateType
     case _ => DoubleType
   }
 
@@ -77,9 +85,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   )
 
   // If all input are nulls, count will be 0 and we will get null after the division.
-  override lazy val evaluateExpression = child.dataType match {
+  override lazy val evaluateExpression: Expression = child.dataType match {
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal(
         sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
+    case _: DateType =>
+      (sum / count).cast(resultType)
     case _ => sum.cast(resultType) / count.cast(resultType)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 76b1944d22a6..ce9d41d791d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -356,7 +356,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(d, IntegerType), null)
     checkEvaluation(cast(d, LongType), null)
     checkEvaluation(cast(d, FloatType), null)
-    checkEvaluation(cast(d, DoubleType), null)
+    checkEvaluation(cast(d, DoubleType), 0.0d)
     checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
     checkEvaluation(cast(d, DecimalType(10, 2)), null)
     checkEvaluation(cast(d, StringType), "1970-01-01")
@@ -644,7 +644,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     assert(cast(1, DateType).checkInputDataTypes().isFailure)
     assert(cast(1L, DateType).checkInputDataTypes().isFailure)
     assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
-    assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
   }
 
   test("SPARK-20302 cast with same structure") {
@@ -1319,6 +1318,21 @@ class CastSuite extends CastSuiteBase {
     }
   }
 
+  private val dateDaysSinceEpoch = 18389.0 // Days since epoch (1970-01-01)
+  private val date = Date.valueOf("2020-05-07")
+
+  test("SPARK-10520: Cast a Date to Double") {
+    withDefaultTimeZone(UTC) {
+      checkEvaluation(cast(Literal(date), DoubleType), dateDaysSinceEpoch)
+    }
+  }
+
+  test("SPARK-10520: Cast a Double to Date") {
+    withDefaultTimeZone(UTC) {
+      checkEvaluation(cast(Literal(dateDaysSinceEpoch), DateType), date)
+    }
+  }
+
   test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") {
     withDefaultTimeZone(UTC) {
       val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e954e2bf1c46..45a9d08b0daf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql
 
+import java.sql.Date
+
 import scala.util.Random
 
 import org.scalatest.matchers.must.Matchers.the
 
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -318,6 +321,14 @@ class DataFrameAggregateSuite extends QueryTest
       Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
   }
 
+  test("SPARK-10520: date average") {
+    withDefaultTimeZone(UTC) {
+      checkAnswer(
+        testDataDates.agg(avg($"a")),
+        Row(new Date(2011, 4, 3)))
+    }
+  }
+
   test("null average") {
     checkAnswer(
       testData3.agg(avg($"b")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index c51faaf10f5d..9dd5a416265d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.test
 
 import java.nio.charset.StandardCharsets
+import java.sql.Date
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
@@ -73,6 +74,16 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val testDataDates: DataFrame = {
+    val df = spark.sparkContext.parallelize(
+      TestDataDate(new Date(2000, 1, 1)) ::
+      TestDataDate(new Date(2010, 1, 1)) ::
+      TestDataDate(new Date(2015, 1, 1)) ::
+      TestDataDate(new Date(2020, 1, 1)) :: Nil, 2).toDF()
+    df.createOrReplaceTempView("testDates")
+    df
+  }
+
   protected lazy val negativeData: DataFrame = {
     val df = spark.sparkContext.parallelize(
       (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
@@ -326,6 +337,7 @@ private[sql] object SQLTestData {
   case class TestData(key: Int, value: String)
   case class TestData2(a: Int, b: Int)
   case class TestData3(a: Int, b: Option[Int])
+  case class TestDataDate(a: Date)
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)
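
For context, a rough spark-shell style sketch of the behaviour this patch is meant to enable. This is illustrative only and not part of the patch: it assumes the change above is applied, a SparkSession bound to `spark`, and made-up column names and values. Dates are stored internally as days since 1970-01-01, which is exactly what the new Date <-> Double casts expose; without the patch the double-to-date cast and the date average both fail analysis.

  import java.sql.Date
  import org.apache.spark.sql.functions.{avg, col, lit}
  import spark.implicits._

  // Two dates that straddle 2020-05-07 (epoch days 18383 and 18395).
  val df = Seq(Date.valueOf("2020-05-01"), Date.valueOf("2020-05-13")).toDF("d")

  df.select(col("d").cast("double")).show()   // 18383.0 and 18395.0 (days since the epoch)
  df.select(lit(18389.0).cast("date")).show() // 1970-01-01 plus 18389 days = 2020-05-07
  df.agg(avg(col("d"))).show()                // with this patch: 2020-05-07, the midpoint date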