Skip to content

Commit 8075b25

Browse files
committed
Use Instant for CurrentDate, more test coverage
1 parent ff68d5b commit 8075b25

File tree

2 files changed

+47
-41
lines changed

2 files changed

+47
-41
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import java.time.{Instant, LocalDateTime}
20+
import java.time.{Instant, LocalDate, LocalDateTime}
2121

2222
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules._
2626
import org.apache.spark.sql.catalyst.trees.TreePattern._
27-
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, currentDate, instantToMicros, localDateTimeToMicros}
27+
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros, localDateToDays}
2828
import org.apache.spark.sql.connector.catalog.CatalogManager
2929
import org.apache.spark.sql.types._
3030
import org.apache.spark.util.Utils
@@ -82,7 +82,9 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
8282
plan.transformDownWithSubqueries {
8383
case subQuery =>
8484
subQuery.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
85-
case cd: CurrentDate => Literal.create(currentDate(cd.zoneId).asInstanceOf[Int], DateType)
85+
case cd: CurrentDate =>
86+
Literal.create(
87+
localDateToDays(LocalDate.ofInstant(instant, cd.zoneId)).asInstanceOf[Int], DateType)
8688
case CurrentTimestamp() | Now() => currentTime
8789
case CurrentTimeZone() => timezone
8890
case localTimestamp: LocalTimestamp =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ package org.apache.spark.sql.catalyst.optimizer
2020
import java.time.{LocalDateTime, ZoneId}
2121

2222
import scala.collection.JavaConverters.mapAsScalaMap
23-
import scala.collection.mutable
23+
import scala.concurrent.duration._
2424

2525
import org.apache.spark.sql.catalyst.dsl.plans._
26-
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimeZone, CurrentTimestamp, InSubquery, ListQuery, Literal, LocalTimestamp}
26+
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
2727
import org.apache.spark.sql.catalyst.plans.PlanTest
2828
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
2929
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -44,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
4444
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
4545
val max = (System.currentTimeMillis() + 1) * 1000
4646

47-
val lits = new scala.collection.mutable.ArrayBuffer[Long]
48-
plan.transformAllExpressions { case e: Literal =>
49-
lits += e.value.asInstanceOf[Long]
50-
e
51-
}
47+
val lits = literals[Long](plan)
5248
assert(lits.size == 2)
5349
assert(lits(0) >= min && lits(0) <= max)
5450
assert(lits(1) >= min && lits(1) <= max)
@@ -62,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
6258
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
6359
val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
6460

65-
val lits = new scala.collection.mutable.ArrayBuffer[Int]
66-
plan.transformAllExpressions { case e: Literal =>
67-
lits += e.value.asInstanceOf[Int]
68-
e
69-
}
61+
val lits = literals[Int](plan)
7062
assert(lits.size == 2)
7163
assert(lits(0) >= min && lits(0) <= max)
7264
assert(lits(1) >= min && lits(1) <= max)
@@ -76,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
7668
test("SPARK-33469: Add current_timezone function") {
7769
val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
7870
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
79-
val lits = new scala.collection.mutable.ArrayBuffer[String]
80-
plan.transformAllExpressions { case e: Literal =>
81-
lits += e.value.asInstanceOf[UTF8String].toString
82-
e
83-
}
71+
val lits = literals[UTF8String](plan)
8472
assert(lits.size == 1)
85-
assert(lits.head == SQLConf.get.sessionLocalTimeZone)
73+
assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
8674
}
8775

8876
test("analyzer should replace localtimestamp with literals") {
@@ -95,11 +83,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
9583
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
9684
val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
9785

98-
val lits = new scala.collection.mutable.ArrayBuffer[Long]
99-
plan.transformAllExpressions { case e: Literal =>
100-
lits += e.value.asInstanceOf[Long]
101-
e
102-
}
86+
val lits = literals[Long](plan)
10387
assert(lits.size == 2)
10488
assert(lits(0) >= min && lits(0) <= max)
10589
assert(lits(1) >= min && lits(1) <= max)
@@ -115,15 +99,9 @@ class ComputeCurrentTimeSuite extends PlanTest {
11599

116100
val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
117101

118-
val literals = new scala.collection.mutable.ArrayBuffer[Long]
119-
plan.transformDownWithSubqueries { case subQuery =>
120-
subQuery.transformAllExpressions { case expression: Literal =>
121-
literals += expression.value.asInstanceOf[Long]
122-
expression
123-
}
124-
}
125-
assert(literals.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
126-
assert(literals.toSet.size == 1)
102+
val lits = literals[Long](plan)
103+
assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
104+
assert(lits.toSet.size == 1)
127105
}
128106

129107
test("analyzer should use consistent timestamps for different timezones") {
@@ -133,12 +111,38 @@ class ComputeCurrentTimeSuite extends PlanTest {
133111

134112
val plan = Optimize.execute(input).asInstanceOf[Project]
135113

136-
val literals = new scala.collection.mutable.ArrayBuffer[Long]
137-
plan.transformAllExpressions { case e: Literal =>
138-
literals += e.value.asInstanceOf[Long]
139-
e
140-
}
114+
val lits = literals[Long](plan)
115+
assert(lits.size === localTimestamps.size)
116+
// there are timezones with a 30 or 45 minute offset
117+
val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
118+
assert(offsetsFromQuarterHour.size == 1)
119+
}
141120

142-
assert(literals.size === localTimestamps.size)
121+
test("analyzer should use consistent timestamps for different timestamp functions") {
122+
val differentTimestamps = Seq(
123+
Alias(CurrentTimestamp(), "currentTimestamp")(),
124+
Alias(Now(), "now")(),
125+
Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")(),
126+
)
127+
val input = Project(differentTimestamps, LocalRelation())
128+
129+
val plan = Optimize.execute(input).asInstanceOf[Project]
130+
131+
val lits = literals[Long](plan)
132+
assert(lits.size === differentTimestamps.size)
133+
// there are timezones with a 30 or 45 minute offset
134+
val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
135+
assert(offsetsFromQuarterHour.size == 1)
136+
}
137+
138+
private def literals[T](plan: LogicalPlan): Seq[T] = {
139+
val literals = new scala.collection.mutable.ArrayBuffer[T]
140+
plan.transformDownWithSubqueries { case subQuery =>
141+
subQuery.transformAllExpressions { case expression: Literal =>
142+
literals += expression.value.asInstanceOf[T]
143+
expression
144+
}
145+
}
146+
literals
143147
}
144148
}

0 commit comments

Comments (0)