@@ -130,6 +130,14 @@ case class Join(
}
}

/**
* A hint for the optimizer that we should broadcast the `child` if used in a join operator.
*/
case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
Contributor:

Could we just have this be StatisticsHint and override the statistics? I'm afraid that we are going to forget to add cases in the future as we broadcast more types of joins.
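A minimal sketch of the StatisticsHint idea (hypothetical; not code from this PR, and it assumes the Statistics(sizeInBytes) result type returned by LogicalPlan.statistics):

// Hypothetical StatisticsHint: passes its child through unchanged but reports
// a tiny size estimate, so any size-based planning rule (not only the
// broadcast-join cases) would treat this subtree as broadcastable.
case class StatisticsHint(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
  override def statistics: Statistics = Statistics(sizeInBytes = 1)
}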

Contributor Author:

That sounds good.

Contributor:

I'm thinking about this more, and maybe what we actually want is to keep this specific node and add a pattern that recognizes canBroadcast. The pattern can check either for a small enough size or for this hint, and we can use it anywhere we plan a broadcast operator. The reasoning is that messing with statistics could have other consequences.

Contributor:

+1
Adding a specific node is also doable; the node could be named something like Hint in LogicalPlan, and it could cover more than broadcast, e.g. uniq_key.

Any ideas?
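For illustration, a generic hint node along these lines might look like the following (hypothetical sketch; the hintType values are made up, and nothing like this is part of this PR):

// Hypothetical generic Hint node: the hint kind is carried as data, so new
// hints ("broadcast", "uniq_key", ...) would not require new plan node classes.
case class Hint(hintType: String, child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
}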

override def output: Seq[Attribute] = child.output
}


case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
@@ -52,6 +52,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

/**
* Matches a plan whose output should be small enough to be used in a broadcast join.
*/
object CanBroadcast {
Contributor Author:

@marmbrus is this what you had in mind?

Contributor:

Yep!

def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
case BroadcastHint(p) => Some(p)
case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p)
case _ => None
}
}

/**
* Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
* evaluated by matching hash keys.
@@ -80,15 +92,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-  if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
-    right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
+case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)

-case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-  if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
-    left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
-  makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
+  makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

// If the sort merge join option is set, we want to use sort merge join prior to hashjoin
// for now let's support inner join first, then add outer join
@@ -329,6 +337,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
+case BroadcastHint(child) => apply(child)
case _ => Nil
}
}
sql/core/src/main/scala/org/apache/spark/sql/functions.scala (17 additions, 0 deletions)
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -565,6 +566,22 @@
array((colName +: colNames).map(col) : _*)
}

/**
* Marks a DataFrame as small enough for use in broadcast joins.
*
* The following example marks the right DataFrame for a broadcast hash join using `joinKey`.
* {{{
* // left and right are DataFrames
* left.join(broadcast(right), "joinKey")
* }}}
*
* @group normal_funcs
* @since 1.5.0
*/
def broadcast(df: DataFrame): DataFrame = {
DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan))
}

/**
* Returns the first column that is not null.
* {{{
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions._

class DataFrameJoinSuite extends QueryTest {
@@ -93,4 +94,20 @@ class DataFrameJoinSuite extends QueryTest {
left.join(right, left("key") === right("key")),
Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
}

test("broadcast join hint") {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")

// equijoin - should be converted into broadcast join
val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan
assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1)

// no join key -- should not be a broadcast join
val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan
assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0)

// planner should not crash without a join
broadcast(df1).queryExecution.executedPlan
}
}
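A natural extension of this test (a hypothetical sketch, not part of this diff; it assumes a sqlContext in scope, as in other Spark SQL suites) would verify that the hint still forces a broadcast join when size-based broadcasting is disabled:

// Hypothetical follow-up check: with the automatic threshold disabled,
// only the explicit hint can still trigger a broadcast join.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "-1")
val plan3 = df1.join(broadcast(df2), "key").queryExecution.executedPlan
assert(plan3.collect { case p: BroadcastHashJoin => p }.size === 1)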