Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ object ResolveHints {
val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
ResolvedHint(plan, isBroadcastable = Option(true))
ResolvedHint(plan, HintInfo(isBroadcastable = Option(true)))
case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
ResolvedHint(plan, isBroadcastable = Option(true))
ResolvedHint(plan, HintInfo(isBroadcastable = Option(true)))

case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
// Don't traverse down these nodes.
Expand Down Expand Up @@ -88,7 +88,7 @@ object ResolveHints {
case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
if (h.parameters.isEmpty) {
// If there is no table alias specified, turn the entire subtree into a BroadcastHint.
ResolvedHint(h.child, isBroadcastable = Option(true))
ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true)))
} else {
// Otherwise, find within the subtree query plans that should be broadcasted.
applyBroadcastHint(h.child, h.parameters.toSet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ abstract class UnaryNode extends LogicalPlan {
}

// Don't propagate rowCount and attributeStats, since they are not estimated here.
Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ import org.apache.spark.util.Utils
* defaults to the product of children's `sizeInBytes`.
* @param rowCount Estimated number of rows.
* @param attributeStats Statistics for Attributes.
* @param isBroadcastable If true, output is small enough to be used in a broadcast join.
* @param hints Query hints.
*/
case class Statistics(
sizeInBytes: BigInt,
rowCount: Option[BigInt] = None,
attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
isBroadcastable: Boolean = false) {
hints: HintInfo = HintInfo()) {

override def toString: String = "Statistics(" + simpleString + ")"

Expand All @@ -65,14 +65,9 @@ case class Statistics(
} else {
""
},
s"isBroadcastable=$isBroadcastable"
s"hints=$hints"
).filter(_.nonEmpty).mkString(", ")
}

/** Must be called when computing stats for a join operator to reset hints. */
def resetHintsForJoin(): Statistics = copy(
isBroadcastable = false
)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
val leftSize = left.stats(conf).sizeInBytes
val rightSize = right.stats(conf).sizeInBytes
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
val isBroadcastable = left.stats(conf).isBroadcastable || right.stats(conf).isBroadcastable

Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable)
Statistics(
sizeInBytes = sizeInBytes,
hints = left.stats(conf).hints.resetForJoin())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't propagate isBroadcastable Hints in Intersect?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually a no-op, since Intersect is always rewritten to a join.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uh, right.

}
}

Expand Down Expand Up @@ -364,7 +364,8 @@ case class Join(
case _ =>
// Make sure we don't propagate isBroadcastable in other joins, because
// they could explode the size.
super.computeStats(conf).resetHintsForJoin()
val stats = super.computeStats(conf)
stats.copy(hints = stats.hints.resetForJoin())
}

if (conf.cboEnabled) {
Expand Down Expand Up @@ -560,7 +561,7 @@ case class Aggregate(
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
rowCount = Some(1),
isBroadcastable = child.stats(conf).isBroadcastable)
hints = child.stats(conf).hints)
} else {
super.computeStats(conf)
}
Expand Down Expand Up @@ -749,7 +750,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
Statistics(
sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
rowCount = Some(rowCount),
isBroadcastable = childStats.isBroadcastable)
hints = childStats.hints)
}
}

Expand All @@ -770,7 +771,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
Statistics(
sizeInBytes = 1,
rowCount = Some(0),
isBroadcastable = childStats.isBroadcastable)
hints = childStats.hints)
} else {
// The output row count of LocalLimit should be the sum of row counts from each partition.
// However, since the number of partitions is not available here, we just use statistics of
Expand Down Expand Up @@ -827,7 +828,7 @@ case class Sample(
}
val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
// Don't propagate column stats, because we don't know the distribution after a sample operation
Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints)
}

override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,31 @@ case class UnresolvedHint(name: String, parameters: Seq[String], child: LogicalP
/**
* A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]].
*/
case class ResolvedHint(
child: LogicalPlan,
isBroadcastable: Option[Boolean] = None)
case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
extends UnaryNode {

override def output: Seq[Attribute] = child.output

override def computeStats(conf: SQLConf): Statistics = {
val stats = child.stats(conf)
isBroadcastable.map(x => stats.copy(isBroadcastable = x)).getOrElse(stats)
stats.copy(hints = hints)
}
}


/**
 * Container for the query hints attached to a plan.
 *
 * @param isBroadcastable Broadcast hint for the plan's output. `None` means no broadcast hint
 *                        has been given.
 */
case class HintInfo(
    isBroadcastable: Option[Boolean] = None) {

  /** Must be called when computing stats for a join operator to reset hints. */
  def resetForJoin(): HintInfo = copy(
    isBroadcastable = None
  )

  /**
   * Renders the hints that are set as a comma-separated list of `name=value` pairs, or
   * `"none"` when no hint is set.
   */
  override def toString: String = {
    // List every hint explicitly instead of casting the elements of `productIterator` to
    // Option: the cast would throw as soon as a non-Option field is added, and a hint that
    // is not listed in the rendering branch would silently produce an empty string.
    val rendered = Seq(
      isBroadcastable.map(x => s"isBroadcastable=$x")
    ).flatten

    if (rendered.isEmpty) "none" else rendered.mkString(", ")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ object AggregateEstimation {
sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
rowCount = Some(outputRows),
attributeStats = outputAttrStats,
isBroadcastable = childStats.isBroadcastable))
hints = childStats.hints))
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ class ResolveHintsSuite extends AnalysisTest {
test("case-sensitive or insensitive parameters") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
ResolvedHint(testRelation, isBroadcastable = Option(true)),
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)

checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
ResolvedHint(testRelation, isBroadcastable = Option(true)),
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)

checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
ResolvedHint(testRelation, isBroadcastable = Option(true)),
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = true)

checkAnalysis(
Expand All @@ -58,28 +58,28 @@ class ResolveHintsSuite extends AnalysisTest {
test("multiple broadcast hint aliases") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
Join(ResolvedHint(testRelation, isBroadcastable = Option(true)),
ResolvedHint(testRelation2, isBroadcastable = Option(true)), Inner, None),
Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None),
caseSensitive = false)
}

test("do not traverse past existing broadcast hints") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"),
ResolvedHint(table("table").where('a > 1), isBroadcastable = Option(true))),
ResolvedHint(testRelation.where('a > 1), isBroadcastable = Option(true)).analyze,
ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))),
ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze,
caseSensitive = false)
}

test("should work for subqueries") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
ResolvedHint(testRelation, isBroadcastable = Option(true)),
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)

checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
ResolvedHint(testRelation, isBroadcastable = Option(true)),
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
caseSensitive = false)

// Negative case: if the alias doesn't match, don't match the original table name.
Expand All @@ -104,7 +104,7 @@ class ResolveHintsSuite extends AnalysisTest {
|SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
""".stripMargin
),
ResolvedHint(testRelation.where('a > 1).select('a), isBroadcastable = Option(true))
ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true)))
.select('a).analyze,
caseSensitive = false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,20 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {

test("BroadcastHint estimation") {
val filter = Filter(Literal(true), plan)
val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4),
rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4))
checkStats(
filter,
expectedStatsCboOn = filterStatsCboOn,
expectedStatsCboOff = filterStatsCboOff)

val broadcastHint = ResolvedHint(filter, isBroadcastable = Option(true))
val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true)))
checkStats(
broadcastHint,
expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))),
expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true)))
)
}

test("limit estimation: limit < child's rowCount") {
Expand Down Expand Up @@ -94,15 +95,13 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
sizeInBytes = 40,
rowCount = Some(10),
attributeStats = AttributeMap(Seq(
AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
isBroadcastable = false)
AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))))
val expectedCboStats =
Statistics(
sizeInBytes = 4,
rowCount = Some(1),
attributeStats = AttributeMap(Seq(
AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
isBroadcastable = false)
AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))))

val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
checkStats(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Matches a plan whose output should be small enough to be used in broadcast join.
*/
private def canBroadcast(plan: LogicalPlan): Boolean = {
plan.stats(conf).isBroadcastable ||
plan.stats(conf).hints.isBroadcastable.getOrElse(false) ||
(plan.stats(conf).sizeInBytes >= 0 &&
plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
}
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.ResolvedHint
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -1020,7 +1020,7 @@ object functions {
*/
def broadcast[T](df: Dataset[T]): Dataset[T] = {
Dataset[T](df.sparkSession,
ResolvedHint(df.logicalPlan, isBroadcastable = Option(true)))(df.exprEnc)
ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
numbers.foreach { case (input, (expectedSize, expectedRows)) =>
val stats = Statistics(sizeInBytes = input, rowCount = Some(input))
val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," +
s" isBroadcastable=${stats.isBroadcastable}"
s" hints=none"
assert(stats.simpleString == expectedString)
}
}
Expand Down