@@ -150,6 +150,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
+ResolveTimeZone(conf) ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -161,8 +162,6 @@
HandleNullInputsForUDF),
Batch("FixNullability", Once,
FixNullability),
Batch("ResolveTimeZone", Once,
ResolveTimeZone),
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
@@ -2347,23 +2346,6 @@ class Analyzer(
}
}
}

-/**
- * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
- * time zone.
- */
-object ResolveTimeZone extends Rule[LogicalPlan] {
-
-override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
-case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
-e.withTimeZone(conf.sessionLocalTimeZone)
-// Casts could be added in the subquery plan through the rule TypeCoercion while coercing
-// the types between the value expression and list query expression of IN expression.
-// We need to subject the subquery plan through ResolveTimeZone again to setup timezone
-// information for time zone aware expressions.
-case e: ListQuery => e.withNewPlan(apply(e.plan))
-}
-}
}

/**
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
@@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
*/
-case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case table: UnresolvedInlineTable if table.expressionsResolved =>
validateInputDimension(table)
@@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
val castedExpr = if (e.dataType.sameType(targetType)) {
e
} else {
-Cast(e, targetType)
+cast(e, targetType)
}
-castedExpr.transform {
-case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
-e.withTimeZone(conf.sessionLocalTimeZone)
-}.eval()
+castedExpr.eval()

Member commented:
If there are nested expressions which are timezone aware, I think we still need to attach time zone to them?

ueshin (Member) replied on Apr 17, 2017:
I guess TimeZoneAwareExpression is now resolved only if it has a timeZoneId, so we don't need to transform the children.

Member replied:
Oh, right. I saw the changes to TimeZoneAwareExpression. :-)

} catch {
case NonFatal(ex) =>
table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
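The thread above hinges on the central behavioral change of this patch: a TimeZoneAwareExpression without a time zone id is no longer resolved, so the analyzer, rather than ad-hoc transforms like the one deleted here, is responsible for attaching one. A minimal sketch of that contract, assuming the ResolveTimeZone rule added in the new file below:

```scala
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.TimestampType

val conf = new SQLConf

// After this patch, a Cast that needs a time zone starts out unresolved...
val c = Cast(Literal("1991-12-06 00:00:00.0"), TimestampType)
assert(!c.resolved) // timeZoneId is None

// ...and becomes resolved once ResolveTimeZone attaches the session-local zone.
val withTz = ResolveTimeZone(conf).resolveTimeZones(c)
assert(withTz.resolved)
withTz.eval() // safe to evaluate now that the time zone is set
```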
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

/**
 * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
 * time zone.
 */
case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
  private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
    case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
      e.withTimeZone(conf.sessionLocalTimeZone)
    // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
    // the types between the value expression and list query expression of IN expression.
    // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
    // information for time zone aware expressions.
    case e: ListQuery => e.withNewPlan(apply(e.plan))
  }

  override def apply(plan: LogicalPlan): LogicalPlan =
    plan.resolveExpressions(transformTimeZoneExprs)

  def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
}

/**
 * Mix-in trait for constructing valid [[Cast]] expressions.
 */
trait CastSupport {
  /**
   * Configuration used to create a valid cast expression.
   */
  def conf: SQLConf

  /**
   * Create a Cast expression with the session local time zone.
   */
  def cast(child: Expression, dataType: DataType): Cast = {
    Cast(child, dataType, Option(conf.sessionLocalTimeZone))
  }
}
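A usage note on the CastSupport mix-in above: a rule declares conf and calls cast(...) instead of constructing Cast(...) directly, so the session-local time zone is attached at construction time. A hypothetical sketch (MyStringifyRule is illustrative only, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StringType

// Hypothetical rule: rewrite non-string literals into casts to string.
// The point is cast(...), which picks up conf.sessionLocalTimeZone, versus
// Cast(...), which would leave timeZoneId empty and the cast unresolved.
case class MyStringifyRule(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case lit: Literal if lit.dataType != StringType => cast(lit, StringType)
  }
}
```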
@@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf
* This should be only done after the batch of Resolution, because the view attributes are not
* completely resolved during the batch of Resolution.
*/
-case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
+case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case v @ View(desc, output, child) if child.resolved && output != child.output =>
val resolver = conf.resolver
@@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n")
} else {
-Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
+Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata))
}
case (_, originAttr) => originAttr
@@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone}
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
* Common base class for time zone aware expressions.
*/
trait TimeZoneAwareExpression extends Expression {
+/** The expression is only resolved when the time zone has been set. */
+override lazy val resolved: Boolean =
+childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined

Contributor commented:
The implementation may have some custom resolution logic, how about `super.resolved && timeZoneId.isDefined`?

Contributor (author) replied:
I tried doing this. Scala does not allow you to call `super.resolved` in `resolved`. :(



/** the timezone ID to be used to evaluate value. */
def timeZoneId: Option[String]
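The Scala limitation the author mentions is easy to reproduce in isolation. A standalone sketch with simplified stand-in traits (not the actual Catalyst classes) shows why the override restates the base condition instead of delegating to super:

```scala
trait Expr {
  def childrenResolved: Boolean = true
  lazy val resolved: Boolean = childrenResolved
}

trait TimeZoneAware extends Expr {
  def timeZoneId: Option[String]

  // Does not compile: Scala rejects `super` on a lazy val
  // (roughly: "super may not be used on lazy value resolved").
  // override lazy val resolved: Boolean = super.resolved && timeZoneId.isDefined

  // So the patch restates the base condition explicitly:
  override lazy val resolved: Boolean = childrenResolved && timeZoneId.isDefined
}
```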
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}

/**
@@ -91,12 +92,13 @@
test("convert TimeZoneAwareExpression") {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
-val converted = ResolveInlineTables(conf).convert(table)
+val withTimeZone = ResolveTimeZone(conf).apply(table)
+val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone)
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
-assert(converted.output.map(_.dataType) == Seq(TimestampType))
-assert(converted.data.size == 1)
-assert(converted.data(0).getLong(0) == correct)
+assert(output.map(_.dataType) == Seq(TimestampType))
+assert(data.size == 1)
+assert(data.head.getLong(0) == correct)
}

test("nullability inference in convert") {
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

@@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest {
}
}

+private val timeZoneResolver = ResolveTimeZone(new SQLConf)
+
+private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
+timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
+}
+
test("WidenSetOperationTypes for except and intersect") {
val firstTable = LocalRelation(
AttributeReference("i", IntegerType)(),
@@ -799,11 +806,10 @@
AttributeReference("f", FloatType)(),
AttributeReference("l", LongType)())

-val wt = TypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)

-val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
-val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
+val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except]
+val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
checkOutput(r1.left, expectedTypes)
checkOutput(r1.right, expectedTypes)
checkOutput(r2.left, expectedTypes)
@@ -838,10 +844,9 @@
AttributeReference("p", ByteType)(),
AttributeReference("q", DoubleType)())

-val wt = TypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)

-val unionRelation = wt(
+val unionRelation = widenSetOperationTypes(
Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
assert(unionRelation.children.length == 4)
checkOutput(unionRelation.children.head, expectedTypes)
@@ -862,17 +867,15 @@
}
}

-val dp = TypeCoercion.WidenSetOperationTypes
-
val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())
val right1 = LocalRelation(
AttributeReference("r", DecimalType(5, 5))())
val expectedType1 = Seq(DecimalType(10, 8))

-val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
-val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
-val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union]
+val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except]
+val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect]

checkOutput(r1.children.head, expectedType1)
checkOutput(r1.children.last, expectedType1)
@@ -891,17 +894,17 @@
val plan2 = LocalRelation(
AttributeReference("r", rType)())

-val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
-val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
-val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union]
+val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except]
+val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect]

checkOutput(r1.children.last, Seq(expectedType))
checkOutput(r2.right, Seq(expectedType))
checkOutput(r3.right, Seq(expectedType))

-val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
-val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
-val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union]
+val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except]
+val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect]

checkOutput(r4.children.last, Seq(expectedType))
checkOutput(r5.left, Seq(expectedType))
@@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String
*/
class CastSuite extends SparkFunSuite with ExpressionEvalHelper {

-private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
+private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = {
v match {
case lit: Expression => Cast(lit, targetType, timeZoneId)
case _ => Cast(Literal(v), targetType, timeZoneId)
@@ -47,7 +47,7 @@
}

private def checkNullCast(from: DataType, to: DataType): Unit = {
-checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null)
+checkEvaluation(cast(Literal.create(null, from), to), null)
}

test("null cast") {
@@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("Seconds") {
assert(Second(Literal.create(null, DateType), gmtId).resolved === false)
-assert(Second(Cast(Literal(d), TimestampType), None).resolved === true)
+assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true)
checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15)
checkEvaluation(Second(Literal(ts), gmtId), 15)
@@ -220,7 +220,7 @@

test("Hour") {
assert(Hour(Literal.create(null, DateType), gmtId).resolved === false)
-assert(Hour(Literal(ts), None).resolved === true)
+assert(Hour(Literal(ts), gmtId).resolved === true)
checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13)
checkEvaluation(Hour(Literal(ts), gmtId), 13)
@@ -246,7 +246,7 @@

test("Minute") {
assert(Minute(Literal.create(null, DateType), gmtId).resolved === false)
-assert(Minute(Literal(ts), None).resolved === true)
+assert(Minute(Literal(ts), gmtId).resolved === true)
checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(
Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10)
@@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -45,7 +47,8 @@
protected def checkEvaluation(
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
val serializer = new JavaSerializer(new SparkConf()).newInstance
-val expr: Expression = serializer.deserialize(serializer.serialize(expression))
+val resolver = ResolveTimeZone(new SQLConf)
+val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
@@ -36,7 +36,7 @@ class SparkPlanner(
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
FileSourceStrategy ::
-DataSourceStrategy ::
+DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
JoinSelection ::