diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 931a75a83d4ef..0275741844bca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, LongType} trait HashJoin { @@ -240,7 +242,10 @@ trait HashJoin { } } -object HashJoin { +object HashJoin extends CastSupport { + + override def conf: SQLConf = SQLConf.get + /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. * @@ -255,14 +260,14 @@ object HashJoin { } var keyExpr: Expression = if (keys.head.dataType != LongType) { - Cast(keys.head, LongType) + cast(keys.head, LongType) } else { keys.head } keys.tail.foreach { e => val bits = e.dataType.defaultSize * 8 keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(cast(e, LongType), Literal((1L << bits) - 1))) } keyExpr :: Nil } @@ -275,13 +280,13 @@ object HashJoin { // jump over keys that have a higher index value than the required key if (keys.size == 1) { assert(index == 0) - Cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) + cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) } else { val shiftedBits = keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1 // build the schema for unpacking the required key - Cast(BitwiseAnd( + cast(BitwiseAnd( ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)), Literal(mask)), keys(index).dataType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 335ef257920c4..ef0a596f21104 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -239,33 +239,40 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) - assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: Nil) === + Cast(i, LongType, Some(conf.sessionLocalTimeZone)) :: Nil) assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === - BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), - BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + BitwiseOr(ShiftLeft(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal(32)), + BitwiseAnd(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 32) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) - assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: Nil) === + Cast(s, LongType, Some(conf.sessionLocalTimeZone)) :: Nil) assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === - BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === BitwiseOr(ShiftLeft( - BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === BitwiseOr(ShiftLeft( BitwiseOr(ShiftLeft( - BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), + Literal((1L << 16) - 1))), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === s :: s :: s :: s :: s :: Nil)