|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.optimizer |
19 | 19 |
|
20 | | -import java.time.LocalDate |
| 20 | +import java.time.{Instant, LocalDate, LocalDateTime, LocalTime, ZoneId} |
21 | 21 |
|
22 | 22 | import org.apache.spark.sql.catalyst.dsl.plans._ |
23 | 23 | import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal} |
24 | 24 | import org.apache.spark.sql.catalyst.plans.PlanTest |
25 | 25 | import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} |
26 | 26 | import org.apache.spark.sql.catalyst.rules.RuleExecutor |
27 | | -import org.apache.spark.sql.types.DateType |
| 27 | +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MINUTE |
| 28 | +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros} |
| 29 | +import org.apache.spark.sql.types.{AtomicType, DateType, TimestampNTZType, TimestampType} |
28 | 30 |
|
29 | 31 | class SpecialDatetimeValuesSuite extends PlanTest { |
30 | 32 | object Optimize extends RuleExecutor[LogicalPlan] { |
@@ -55,4 +57,45 @@ class SpecialDatetimeValuesSuite extends PlanTest { |
55 | 57 | assert(expected === lits.toSet) |
56 | 58 | } |
57 | 59 | } |
| 60 | + |
| 61 | + private def testSpecialTs(tsType: AtomicType, expected: Set[Long], zoneId: ZoneId): Unit = { |
| 62 | + val in = Project(Seq( |
| 63 | + Alias(Cast(Literal("epoch"), tsType, Some(zoneId.getId)), "epoch")(), |
| 64 | + Alias(Cast(Literal("now"), tsType, Some(zoneId.getId)), "now")(), |
| 65 | + Alias(Cast(Literal("tomorrow"), tsType, Some(zoneId.getId)), "tomorrow")(), |
| 66 | + Alias(Cast(Literal("yesterday"), tsType, Some(zoneId.getId)), "yesterday")()), |
| 67 | + LocalRelation()) |
| 68 | + |
| 69 | + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] |
| 70 | + val lits = new scala.collection.mutable.ArrayBuffer[Long] |
| 71 | + plan.transformAllExpressions { case e: Literal if e.dataType == tsType => |
| 72 | + lits += e.value.asInstanceOf[Long] |
| 73 | + e |
| 74 | + } |
| 75 | + assert(lits.forall(ts => expected.exists(ets => Math.abs(ets -ts) <= MICROS_PER_MINUTE))) |
| 76 | + } |
| 77 | + |
| 78 | + test("special timestamp_ltz values") { |
| 79 | + testSpecialDatetimeValues { zoneId => |
| 80 | + val expected = Set( |
| 81 | + Instant.ofEpochSecond(0), |
| 82 | + Instant.now(), |
| 83 | + Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).plusDays(1).toInstant, |
| 84 | + Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).minusDays(1).toInstant |
| 85 | + ).map(instantToMicros) |
| 86 | + testSpecialTs(TimestampType, expected, zoneId) |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + test("special timestamp_ntz values") { |
| 91 | + testSpecialDatetimeValues { zoneId => |
| 92 | + val expected = Set( |
| 93 | + LocalDateTime.of(1970, 1, 1, 0, 0), |
| 94 | + LocalDateTime.now(), |
| 95 | + LocalDateTime.now().`with`(LocalTime.MIDNIGHT).plusDays(1), |
| 96 | + LocalDateTime.now().`with`(LocalTime.MIDNIGHT).minusDays(1) |
| 97 | + ).map(localDateTimeToMicros) |
| 98 | + testSpecialTs(TimestampNTZType, expected, zoneId) |
| 99 | + } |
| 100 | + } |
58 | 101 | } |
0 commit comments