From 974fadb748488c1375d8ba93b6b4b3290b0e02e2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 7 Oct 2016 10:20:07 -0700 Subject: [PATCH 1/2] fix join key rewritten --- .../spark/sql/execution/joins/HashJoin.scala | 71 ++++++++++--------- .../execution/joins/BroadcastJoinSuite.scala | 47 ++++++++++++ 2 files changed, 85 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb6bfa7b2735..5d776d62b62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -63,45 +63,16 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = HashJoin.rewriteKeyExpr(rightKeys) + .map(BindReferences.bindReference(_, right.output)) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) } } - /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ - private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { - var keyExpr: Expression = null - var width = 0 - keys.foreach { e => - e.dataType match { - case dt: IntegralType if dt.defaultSize <= 8 - width => - if (width == 0) { - if (e.dataType != LongType) { - keyExpr = Cast(e, LongType) - } else { - keyExpr = e - } - width = dt.defaultSize - } else { - val bits = dt.defaultSize * 8 - keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) - width -= bits - } - // TODO: support BooleanType, DateType and TimestampType - case other => - return keys - } - } - keyExpr :: Nil - } + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(buildKeys) @@ -247,3 +218,37 @@ trait HashJoin { } } } + +object HashJoin { + /** + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. + */ + private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + var keyExpr: Expression = null + var width = 0 + keys.foreach { e => + e.dataType match { + case dt: IntegralType if dt.defaultSize <= 8 - width => + if (width == 0) { + if (e.dataType != LongType) { + keyExpr = Cast(e, LongType) + } else { + keyExpr = e + } + width = dt.defaultSize + } else { + val bits = dt.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + width += dt.defaultSize + } + // TODO: support BooleanType, DateType and TimestampType + case other => + return keys + } + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 97adffa8ce10..83db81ea3f1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -21,11 +21,13 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{LongType, ShortType} /** * Test various broadcast join operators. @@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { cases.foreach(assertBroadcastJoin) } } + + test("join key rewritten") { + val l = Literal(1L) + val i = Literal(2) + val s = Literal.create(3, ShortType) + val ss = Literal("hello") + + assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === + BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), + BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === + s :: s :: s :: s :: s :: Nil) + + assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) + } } From 87f6b24ac7dc34f1f3830cb98400d81161fc7be9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 7 Oct 2016 12:12:56 -0700 Subject: [PATCH 2/2] CR --- .../spark/sql/execution/joins/HashJoin.scala | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5d776d62b62a..275f0f7863ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -226,28 +226,22 @@ object HashJoin { * If not, returns the original expressions. */ private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { - var keyExpr: Expression = null - var width = 0 - keys.foreach { e => - e.dataType match { - case dt: IntegralType if dt.defaultSize <= 8 - width => - if (width == 0) { - if (e.dataType != LongType) { - keyExpr = Cast(e, LongType) - } else { - keyExpr = e - } - width = dt.defaultSize - } else { - val bits = dt.defaultSize * 8 - keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) - width += dt.defaultSize - } - // TODO: support BooleanType, DateType and TimestampType - case other => - return keys - } + assert(keys.nonEmpty) + // TODO: support BooleanType, DateType and TimestampType + if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) + || keys.map(_.dataType.defaultSize).sum > 8) { + return keys + } + + var keyExpr: Expression = if (keys.head.dataType != LongType) { + Cast(keys.head, LongType) + } else { + keys.head + } + keys.tail.foreach { e => + val bits = e.dataType.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) } keyExpr :: Nil }