@@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
BroadcastNestedLoop ::
Contributor Author:

Add the BroadcastNestedLoop strategy, which should be tried before CartesianProduct.
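Strategy order matters because the planner returns plans from the first strategy that produces any. A minimal sketch of that first-match behavior, with simplified signatures (not the actual QueryPlanner code):

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Toy first-match loop: strategies are tried in declaration order, and the
// first one to emit a candidate plan wins. Placing BroadcastNestedLoop before
// CartesianProduct therefore lets a broadcastable side bypass the cartesian path.
def firstMatch(strategies: Seq[LogicalPlan => Seq[SparkPlan]], plan: LogicalPlan): SparkPlan =
  strategies.view.flatMap(s => s(plan)).headOption
    .getOrElse(sys.error(s"No strategy matched: $plan"))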

CartesianProduct ::
BroadcastNestedLoopJoin :: Nil)
DefaultJoin :: Nil)

/**
* Used to build table scan operators where complex projection and filtering are done using
@@ -294,25 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}


object BroadcastNestedLoopJoin extends Strategy {
object BroadcastNestedLoop extends Strategy {
Contributor Author:

The renaming is intentional: some users extend SQLContext, and they may not be aware of this change.
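For illustration, a hedged sketch of the public extension point users typically combine with the built-in strategies (MyStrategy is hypothetical). Anyone who instead copied the built-in strategy list verbatim would have referenced BroadcastNestedLoopJoin by name, so the rename fails their build loudly rather than silently changing plan selection:

import org.apache.spark.sql.{SQLContext, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical user strategy, registered through the experimental API.
object MyStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil // no-op placeholder
}

def install(ctx: SQLContext): Unit = {
  // extraStrategies run before the built-in ones.
  ctx.experimental.extraStrategies = MyStrategy +: ctx.experimental.extraStrategies
}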

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
joins.BuildRight
} else {
joins.BuildLeft
}
joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case logical.Join(
CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi =>
execution.joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
case logical.Join(
left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi =>
execution.joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
case _ => Nil
}
}
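For context, both new cases lean on the CanBroadcast extractor. A minimal sketch of its shape under the Spark 1.x conventions (simplified — the real one reads the threshold from SQLConf, here a hard-coded stand-in):

import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}

// Simplified sketch: a side qualifies for broadcast if it carries an explicit
// hint, or if its estimated size fits under the broadcast threshold.
object CanBroadcastSketch {
  val autoBroadcastJoinThreshold: Long = 10L * 1024 * 1024 // stand-in for the conf value

  def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
    case BroadcastHint(p) => Some(p)
    case p if autoBroadcastJoinThreshold > 0 &&
        p.statistics.sizeInBytes <= autoBroadcastJoinThreshold => Some(p)
    case _ => None
  }
}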

object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, _, None) =>
// TODO CartesianProduct doesn't support the Left Semi Join
case logical.Join(left, right, joinType, None) if joinType != LeftSemi =>
execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
@@ -321,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

object DefaultJoin extends Strategy {
Contributor Author:

Add the new rule object as the real last gate. For now it still generates the BroadcastNestedLoopJoin operator, as we do today; we can probably change that in the future.

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
joins.BuildRight
} else {
joins.BuildLeft
}
joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case _ => Nil
}
}
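Taken together, the three strategies form a chain: BroadcastNestedLoop fires when either side is broadcastable, CartesianProduct picks up the remaining condition-free joins, and DefaultJoin is the unconditional last gate. A toy simulation of that selection, with invented sizes and strings standing in for statistics and physical operators:

// Toy walk-through of the strategy chain; not Spark code.
object JoinChainDemo extends App {
  val threshold = 10L * 1024 * 1024 // hypothetical autoBroadcastJoinThreshold

  def pick(leftBytes: Long, rightBytes: Long, hasCondition: Boolean): String =
    if (leftBytes <= threshold || rightBytes <= threshold)
      "BroadcastNestedLoopJoin (via BroadcastNestedLoop)" // small side is broadcast
    else if (!hasCondition)
      "CartesianProduct" // no condition, neither side fits
    else
      "BroadcastNestedLoopJoin (via DefaultJoin)" // last-gate fallback

  println(pick(1L << 20, 1L << 34, hasCondition = false)) // broadcast the 1 MB side
  println(pick(1L << 34, 1L << 34, hasCondition = false)) // plain cartesian product
  println(pick(1L << 34, 1L << 34, hasCondition = true))  // default fallback
}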

protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)

object TakeOrderedAndProject extends Strategy {
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.CompactBuffer
@@ -67,7 +67,10 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
case x =>
case Inner =>
// TODO: we could avoid breaking the lineage, since we union an empty RDD for the Inner join case
left.output ++ right.output
case x => // TODO support the Left Semi Join
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin should not take $x as the JoinType")
}
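The nullability handling is the subtle part: an outer join null-pads the non-matching side, so those attributes must be marked nullable, while the new Inner case can pass both outputs through untouched. A standalone illustration with a hypothetical attribute (not tied to this diff):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

object NullabilityDemo extends App {
  // Non-nullable on its own table...
  val a = AttributeReference("a", IntegerType, nullable = false)()
  println(a.nullable) // false
  // ...but it must become nullable once its side can be null-padded by an outer join.
  println(a.withNullability(true).nullable) // true
}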
92 changes: 92 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {

setupTestData()

def statisticSizeInByte(df: DataFrame): BigInt = {
df.queryExecution.optimizedPlan.statistics.sizeInBytes
}
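A hedged usage sketch of this helper (relies on the SharedSQLContext fixture): compare a table's optimized-plan statistics against the session threshold to predict whether it would be broadcast.

// Sketch: would `testData` be broadcast under the current threshold?
val wouldBroadcast =
  statisticSizeInByte(sqlContext.table("testData")) <=
    sqlContext.conf.autoBroadcastJoinThreshold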

test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
@@ -466,6 +470,94 @@
sql("UNCACHE TABLE testData")
}

test("cross join with broadcast") {
sql("CACHE TABLE testData")

val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData"))

// set the broadcast threshold to be greater than the statistics of the cached table testData
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {

assert(statisticSizeInByte(sqlContext.table("testData2")) >
sqlContext.conf.autoBroadcastJoinThreshold)

assert(statisticSizeInByte(sqlContext.table("testData")) <
sqlContext.conf.autoBroadcastJoinThreshold)

Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2",
classOf[LeftSemiJoinBNL]),
("SELECT * FROM testData JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData LEFT JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData RIGHT JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData FULL OUTER JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2 WHERE key > a",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }

checkAnswer(
sql(
"""
SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2
""".stripMargin),
Row("2", 1, 1) ::
Row("2", 1, 2) ::
Row("2", 2, 1) ::
Row("2", 2, 2) ::
Row("2", 3, 1) ::
Row("2", 3, 2) :: Nil)

checkAnswer(
sql(
"""
SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a
""".stripMargin),
Row("1", 2, 1) ::
Row("1", 2, 2) ::
Row("1", 3, 1) ::
Row("1", 3, 2) ::
Row("2", 3, 1) ::
Row("2", 3, 2) :: Nil)

checkAnswer(
sql(
"""
SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a
""".stripMargin),
Row("1", 2, 1) ::
Row("1", 2, 2) ::
Row("1", 3, 1) ::
Row("1", 3, 2) ::
Row("2", 3, 1) ::
Row("2", 3, 2) :: Nil)
}

sql("UNCACHE TABLE testData")
}

test("left semi join") {
val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,
@@ -587,8 +587,9 @@ class HiveContext private[hive](
LeftSemiJoin,
EquiJoinSelection,
BasicOperators,
BroadcastNestedLoop,
CartesianProduct,
BroadcastNestedLoopJoin
DefaultJoin
)
}

@@ -0,0 +1,20 @@
302 0
302 0
302 0
305 0
305 0
305 0
306 0
306 0
306 0
307 0
307 0
307 0
307 0
307 0
307 0
308 0
308 0
308 0
309 0
309 0
@@ -0,0 +1,20 @@
302 0
302 0
302 0
305 0
305 0
305 0
305 2
305 4
306 0
306 0
306 0
306 2
306 4
306 5
306 5
306 5
307 0
307 0
307 0
307 0
@@ -0,0 +1,20 @@
302 0
302 0
302 0
305 0
305 0
305 0
305 2
305 4
306 0
306 0
306 0
306 2
306 4
306 5
306 5
306 5
307 0
307 0
307 0
307 0
@@ -0,0 +1,20 @@
302 0
302 0
302 0
305 0
305 0
305 0
305 2
305 4
306 0
306 0
306 0
306 2
306 4
306 5
306 5
306 5
307 0
307 0
307 0
307 0
@@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.execution
import java.io.File
import java.util.{Locale, TimeZone}

import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin

import scala.util.Try

import org.scalatest.BeforeAndAfter
@@ -69,6 +71,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}

// Test the broadcast-based join for the cartesian join (cross join).
// We assume the broadcast join threshold will kick in, since src is a small table.
private val spark_10484_1 = """
| SELECT a.key, b.key
| FROM src a LEFT JOIN src b WHERE a.key > b.key + 300
| ORDER BY b.key, a.key
| LIMIT 20
""".stripMargin
private val spark_10484_2 = """
| SELECT a.key, b.key
| FROM src a RIGHT JOIN src b WHERE a.key > b.key + 300
| ORDER BY a.key, b.key
| LIMIT 20
""".stripMargin
private val spark_10484_3 = """
| SELECT a.key, b.key
| FROM src a FULL OUTER JOIN src b WHERE a.key > b.key + 300
| ORDER BY a.key, b.key
| LIMIT 20
""".stripMargin
private val spark_10484_4 = """
| SELECT a.key, b.key
| FROM src a JOIN src b WHERE a.key > b.key + 300
| ORDER BY a.key, b.key
| LIMIT 20
""".stripMargin

createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1",
spark_10484_1)

createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2",
spark_10484_2)

createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3",
spark_10484_3)

createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4",
spark_10484_4)

test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") {
def assertBroadcastNestedLoopJoin(sqlText: String): Unit = {
assert(sql(sqlText).queryExecution.sparkPlan.collect {
case _: BroadcastNestedLoopJoin => 1
}.nonEmpty)
}

assertBroadcastNestedLoopJoin(spark_10484_1)
assertBroadcastNestedLoopJoin(spark_10484_2)
assertBroadcastNestedLoopJoin(spark_10484_3)
assertBroadcastNestedLoopJoin(spark_10484_4)
}

createQueryTest("SPARK-8976 Wrong Result for Rollup #1",
"""
SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP