Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,45 +63,16 @@ trait HashJoin {
/**
 * Join key expressions for the build side and the streamed side, in that order.
 *
 * Both sides' keys must have pairwise-identical data types. Each side's keys are passed through
 * `rewriteKeyExpr` and then bound to the output attributes of the corresponding child, and
 * `buildSide` decides which of the two becomes the build keys.
 */
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
.map(BindReferences.bindReference(_, right.output))
// Order the pair as (build, streamed) according to which child is being built.
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
}
}

/**
 * Try to rewrite the join keys as a single LongType expression so we can use getLong(), if all
 * of the keys together fit within a long.
 *
 * The rewrite only applies when every key is an IntegralType and the sum of the keys' default
 * sizes is at most 8 bytes. The first key is cast to LongType (unless it already is one); each
 * following key is appended by shifting the accumulated value left by that key's bit width and
 * OR-ing in the key masked to its own width.
 *
 * @param keys the join key expressions of one side of the join
 * @return a single-element sequence holding the packed LongType expression, or `keys` unchanged
 *         when the keys are empty, contain a non-integral type, or need more than 8 bytes
 */
private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
  // TODO: support BooleanType, DateType and TimestampType
  // Check the total width up front: the sizes are in bytes (bits = defaultSize * 8 below), and
  // a long holds 8 of them. Tracking the width incrementally while packing is easy to get wrong.
  val fitsInLong = keys.nonEmpty &&
    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
    keys.map(_.dataType.defaultSize).sum <= 8

  if (!fitsInLong) {
    keys
  } else {
    val first: Expression =
      if (keys.head.dataType != LongType) Cast(keys.head, LongType) else keys.head
    val packed = keys.tail.foldLeft(first) { (acc, e) =>
      val bits = e.dataType.defaultSize * 8
      // Keep only the low `bits` bits of the widened key so it cannot overwrite the bits that
      // belong to the keys already packed into `acc`.
      BitwiseOr(ShiftLeft(acc, Literal(bits)),
        BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
    }
    packed :: Nil
  }
}


/** Creates an unsafe projection over the build-side join key expressions. */
protected def buildSideKeyGenerator(): Projection = {
  UnsafeProjection.create(buildKeys)
}
Expand Down Expand Up @@ -247,3 +218,31 @@ trait HashJoin {
}
}
}

object HashJoin {
/**
* Try to rewrite the key as LongType so we can use getLong(), if the key can fit within a long.
*
* If not, returns the original expressions.
*/
private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a lot clearer and less error-prone if we wrote the check first and went through the data twice, e.g.

if (keys.map(_.dataType.defaultSize).sum > 8) {
  return keys
}

// do the rewrite here

// At least one key is required; keys.head is used below.
assert(keys.nonEmpty)
// TODO: support BooleanType, DateType and TimestampType
// Keep the original keys unless every key is integral and all of them together need at most
// 8 bytes (defaultSize is in bytes: it is multiplied by 8 below to get a bit count).
if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
|| keys.map(_.dataType.defaultSize).sum > 8) {
return keys
}

// Start from the first key, widened to LongType when it is not one already.
var keyExpr: Expression = if (keys.head.dataType != LongType) {
Cast(keys.head, LongType)
} else {
keys.head
}
// Append each remaining key: shift what has been packed so far left by the key's bit width,
// then OR in the key masked to its low `bits` bits so it cannot touch the earlier keys' bits.
keys.tail.foreach { e =>
val bits = e.dataType.defaultSize * 8
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
}
// The packed expression replaces the whole key list.
keyExpr :: Nil
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import scala.reflect.ClassTag

import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{LongType, ShortType}

/**
* Test various broadcast join operators.
Expand Down Expand Up @@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
cases.foreach(assertBroadcastJoin)
}
}

// Checks HashJoin.rewriteKeyExpr: integral keys that fit in 8 bytes in total are packed into a
// single LongType expression; every other key list is returned unchanged.
test("join key rewritten") {
  val l = Literal(1L)
  val i = Literal(2)
  // The literal's value must match its declared type, so use a Short rather than an Int.
  val s = Literal.create(3.toShort, ShortType)
  val ss = Literal("hello")

  // Expected packed forms, built up one key at a time: shift the accumulated value left by the
  // next key's bit width and OR in that key masked to its own width.
  val lowInt = BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))
  val twoInts = BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), lowInt)
  val lowShort = BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))
  val twoShorts = BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), lowShort)
  val threeShorts = BitwiseOr(ShiftLeft(twoShorts, Literal(16)), lowShort)
  val fourShorts = BitwiseOr(ShiftLeft(threeShorts, Literal(16)), lowShort)

  // A long key already fills the 8 bytes: alone it is kept as is, with company nothing fits.
  assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
  assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
  assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)

  // Int keys: up to two fit (2 * 4 bytes); three do not.
  assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
  assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
  assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === twoInts :: Nil)
  assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)

  // Short keys: up to four fit (4 * 2 bytes); five do not.
  assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
  assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
  assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === twoShorts :: Nil)
  assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === threeShorts :: Nil)
  assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === fourShorts :: Nil)
  assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
    s :: s :: s :: s :: s :: Nil)

  // Any non-integral key disables the rewrite.
  assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil)
  assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
  assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
}
}