Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-33832][SQL] Support optimize skewed join even if introduce extra shuffle #32816

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
49e8bd9
Support optimize skew join even if introduce extra shuffle
ulysses-you Jun 8, 2021
db77ddd
EnsureRequirements
ulysses-you Jun 8, 2021
a63cd72
make a new rules
ulysses-you Jun 9, 2021
59a5e4a
fix local reader number
ulysses-you Jun 10, 2021
9c985da
more cost
ulysses-you Jul 2, 2021
e2102a5
Merge branch 'master' of https://github.com/apache/spark into support…
ulysses-you Jul 9, 2021
8bc22ad
nit
ulysses-you Jul 9, 2021
7734d3e
nit
ulysses-you Jul 9, 2021
cbc7553
Merge branch 'master' of https://github.com/apache/spark into support…
ulysses-you Aug 3, 2021
3dc61a3
force optimize skewed join
ulysses-you Aug 3, 2021
30b7de0
style
ulysses-you Aug 3, 2021
6caa4a3
name
ulysses-you Aug 3, 2021
cd1a379
final stage
ulysses-you Aug 13, 2021
2b3bfe6
style
ulysses-you Aug 13, 2021
d305894
conflick
ulysses-you Aug 13, 2021
6725f97
checkDistribution
ulysses-you Aug 13, 2021
7a0448b
SimpleCostEvaluator
ulysses-you Aug 19, 2021
60b7b9d
address comment
ulysses-you Aug 19, 2021
fbf9727
cost
ulysses-you Aug 19, 2021
b54e9c2
plan twice
ulysses-you Aug 20, 2021
f5ad40e
nit
ulysses-you Aug 20, 2021
8058fe9
nit
ulysses-you Aug 20, 2021
369bf33
ensureRequiredDistribution
ulysses-you Aug 25, 2021
d93c3df
remove dead code
ulysses-you Aug 25, 2021
b215e2d
simplify code
ulysses-you Aug 25, 2021
5b63e4d
address comment
ulysses-you Aug 25, 2021
3ccc29b
style
ulysses-you Aug 25, 2021
bc45d70
fix order
ulysses-you Aug 25, 2021
580a0a4
address comment
ulysses-you Aug 26, 2021
bc39694
address comment
ulysses-you Aug 26, 2021
bb2e713
address comment
ulysses-you Sep 2, 2021
d3f0131
nit
ulysses-you Sep 2, 2021
4712986
nit
ulysses-you Sep 2, 2021
23ebea0
address comment
ulysses-you Sep 5, 2021
ef0765f
pass EnsureRequirements
ulysses-you Sep 7, 2021
76c363d
simplify
ulysses-you Sep 7, 2021
8961084
nit
ulysses-you Sep 7, 2021
5ba73c4
EnsureRequirements
ulysses-you Sep 7, 2021
ca63321
pull out shuffle origin
ulysses-you Sep 9, 2021
f5e4b91
address comment
ulysses-you Sep 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN =
buildConf("spark.sql.adaptive.forceOptimizeSkewedJoin")
.doc("When true, force enable OptimizeSkewedJoin even if it introduces extra shuffle.")
.version("3.3.0")
.booleanConf
.createWithDefault(false)

val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS =
buildConf("spark.sql.adaptive.customCostEvaluatorClass")
.doc("The custom cost evaluator class to be used for adaptive execution. If not being set," +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,36 @@ case class AdaptiveSparkPlanExec(
AQEUtils.getRequiredDistribution(inputPlan)
}

@transient private val costEvaluator =
conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
case _ => SimpleCostEvaluator(conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN))
}

// A list of physical plan rules to be applied before creation of query stages. The physical
// plan should reach a final status of query stages (i.e., no more addition or removal of
// Exchange nodes) after running these rules.
@transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
RemoveRedundantProjects,
@transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = {
// For cases like `df.repartition(a, b).select(c)`, there is no distribution requirement for
// the final plan, but we do need to respect the user-specified repartition. Here we ask
// `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
// around this case.
EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan
) ++ context.session.sessionState.queryStagePrepRules
val ensureRequirements =
EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
Seq(
RemoveRedundantProjects,
ensureRequirements,
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan,
OptimizeSkewedJoin(ensureRequirements, costEvaluator)
) ++ context.session.sessionState.queryStagePrepRules
}

// A list of physical optimizer rules to be applied to a new stage before its execution. These
// optimizations should be stage-independent.
@transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
PlanAdaptiveDynamicPruningFilters(this),
ReuseAdaptiveSubquery(context.subqueryCache),
// Skew join does not handle `AQEShuffleRead` so needs to be applied first.
OptimizeSkewedJoin,
OptimizeSkewInRebalancePartitions,
CoalesceShufflePartitions(context.session),
// `OptimizeShuffleWithLocalRead` needs to make use of 'AQEShuffleReadExec.partitionSpecs'
Expand Down Expand Up @@ -169,12 +178,6 @@ case class AdaptiveSparkPlanExec(
optimized
}

@transient private val costEvaluator =
conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
case _ => SimpleCostEvaluator
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just move this code up


@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ import scala.collection.mutable
import org.apache.commons.io.FileUtils

import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

Expand All @@ -48,9 +49,10 @@ import org.apache.spark.sql.internal.SQLConf
* (L3, R3-1), (L3, R3-2),
* (L4-1, R4-1), (L4-2, R4-1), (L4-1, R4-2), (L4-2, R4-2)
*/
object OptimizeSkewedJoin extends AQEShuffleReadRule {

override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
case class OptimizeSkewedJoin(
ensureRequirements: EnsureRequirements,
costEvaluator: CostEvaluator)
extends Rule[SparkPlan] {

/**
* A partition is considered as a skewed partition if its size is larger than the median
Expand Down Expand Up @@ -250,15 +252,26 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
// SHJ
// Shuffle
// Shuffle
optimizeSkewJoin(plan)
val optimized = ensureRequirements.apply(optimizeSkewJoin(plan))
val originCost = costEvaluator.evaluateCost(plan)
val optimizedCost = costEvaluator.evaluateCost(optimized)
// two cases we will pick new plan:
// 1. optimize the skew join without extra shuffle
// 2. optimize the skew join with extra shuffle but the costEvaluator think it's better
if (optimizedCost <= originCost) {
optimized
} else {
plan
}
} else {
plan
}
}

object ShuffleStage {
def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
case s: ShuffleQueryStageExec if s.mapStats.isDefined && isSupported(s.shuffle) =>
case s: ShuffleQueryStageExec if s.isMaterialized && s.mapStats.isDefined &&
s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS =>
Some(s)
case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.joins.ShuffledJoin

/**
* A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
Expand All @@ -35,15 +36,52 @@ case class SimpleCost(value: Long) extends Cost {
}

/**
* A simple implementation of [[CostEvaluator]], which counts the number of
* [[ShuffleExchangeLike]] nodes in the plan.
* A skew join aware implementation of [[Cost]], which consider shuffle number and skew join number.
*
* We always pick the cost which has more skew join even if it introduces one or more extra shuffle.
* Otherwise, if two costs have the same number of skew join or no skew join, we will pick the one
* with small number of shuffle.
*/
object SimpleCostEvaluator extends CostEvaluator {
case class SkewJoinAwareCost(
numShuffles: Int,
numSkewJoins: Int) extends Cost {
override def compare(that: Cost): Int = that match {
case other: SkewJoinAwareCost =>
// If more skew joins are optimized or less shuffle nodes, it means the cost is lower
if (numSkewJoins > other.numSkewJoins) {
-1
} else if (numSkewJoins < other.numSkewJoins) {
1
} else if (numShuffles < other.numShuffles) {
-1
} else if (numShuffles > other.numShuffles) {
1
} else {
0
}

case _ =>
throw QueryExecutionErrors.cannotCompareCostWithTargetCostError(that.toString)
}
}

/**
* A skew join aware implementation of [[CostEvaluator]], which counts the number of
* [[ShuffleExchangeLike]] nodes and skew join nodes in the plan.
*/
case class SimpleCostEvaluator(forceOptimizeSkewedJoin: Boolean) extends CostEvaluator {
override def evaluateCost(plan: SparkPlan): Cost = {
val cost = plan.collect {
val numShuffles = plan.collect {
case s: ShuffleExchangeLike => s
}.size
SimpleCost(cost)

if (forceOptimizeSkewedJoin) {
val numSkewJoins = plan.collect {
case j: ShuffledJoin if j.isSkewJoin => j
}.size
SkewJoinAwareCost(numShuffles, numSkewJoins)
} else {
SimpleCost(numShuffles)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,31 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
* but can be false in AQE when AQE optimization may change the plan
* output partitioning and need to retain the user-specified
* repartition shuffles in the plan.
* @param requiredDistribution The root required distribution we should ensure. This value is used
* in AQE in case we change final stage output partitioning.
*/
case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] {

private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
var children: Seq[SparkPlan] = operator.children
assert(requiredChildDistributions.length == children.length)
assert(requiredChildOrderings.length == children.length)
case class EnsureRequirements(
optimizeOutRepartition: Boolean = true,
requiredDistribution: Option[Distribution] = None)
extends Rule[SparkPlan] {

private def ensureDistributionAndOrdering(
originalChildren: Seq[SparkPlan],
requiredChildDistributions: Seq[Distribution],
requiredChildOrderings: Seq[Seq[SortOrder]],
shuffleOrigin: ShuffleOrigin): Seq[SparkPlan] = {
assert(requiredChildDistributions.length == originalChildren.length)
assert(requiredChildOrderings.length == originalChildren.length)
// Ensure that the operator's children satisfy their output distribution requirements.
children = children.zip(requiredChildDistributions).map {
var newChildren = originalChildren.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
val numPartitions = distribution.requiredNumPartitions
.getOrElse(conf.numShufflePartitions)
ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
}

// Get the indexes of children which have specified distribution requirements and need to have
Expand All @@ -69,7 +74,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
}.map(_._2)

val childrenNumPartitions =
childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
childrenIndexes.map(newChildren(_).outputPartitioning.numPartitions).toSet

if (childrenNumPartitions.size > 1) {
// Get the number of partitions which is explicitly required by the distributions.
Expand All @@ -78,7 +83,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
index => requiredChildDistributions(index).requiredNumPartitions
}.toSet
assert(numPartitionsSet.size <= 1,
s"$operator have incompatible requirements of the number of partitions for its children")
s"$requiredChildDistributions have incompatible requirements of the number of partitions")
numPartitionsSet.headOption
}

Expand All @@ -87,7 +92,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
// 1. We should avoid shuffling these children.
// 2. We should have a reasonable parallelism.
val nonShuffleChildrenNumPartitions =
childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
childrenIndexes.map(newChildren).filterNot(_.isInstanceOf[ShuffleExchangeExec])
.map(_.outputPartitioning.numPartitions)
val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
Expand All @@ -106,7 +111,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru

val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)

children = children.zip(requiredChildDistributions).zipWithIndex.map {
newChildren = newChildren.zip(requiredChildDistributions).zipWithIndex.map {
case ((child, distribution), index) if childrenIndexes.contains(index) =>
if (child.outputPartitioning.numPartitions == targetNumPartitions) {
child
Expand All @@ -124,7 +129,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
}

// Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
newChildren = newChildren.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
// If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
child
Expand All @@ -133,7 +138,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
}
}

operator.withNewChildren(children)
newChildren
}

private def reorder(
Expand Down Expand Up @@ -254,25 +259,50 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
}
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
if optimizeOutRepartition &&
(shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
partitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => true
case lower: PartitioningCollection =>
lower.partitionings.exists(hasSemanticEqualPartitioning)
case _ => false
def apply(plan: SparkPlan): SparkPlan = {
val newPlan = plan.transformUp {
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
if optimizeOutRepartition &&
(shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
partitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => true
case lower: PartitioningCollection =>
lower.partitionings.exists(hasSemanticEqualPartitioning)
case _ => false
}
}
}
if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
child
if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
child
} else {
operator
}

case operator: SparkPlan =>
val reordered = reorderJoinPredicates(operator)
val newChildren = ensureDistributionAndOrdering(
reordered.children,
reordered.requiredChildDistribution,
reordered.requiredChildOrdering,
ENSURE_REQUIREMENTS)
reordered.withNewChildren(newChildren)
}

if (requiredDistribution.isDefined) {
val shuffleOrigin = if (requiredDistribution.get.requiredNumPartitions.isDefined) {
REPARTITION_BY_NUM
} else {
operator
REPARTITION_BY_COL
}

case operator: SparkPlan =>
ensureDistributionAndOrdering(reorderJoinPredicates(operator))
val finalPlan = ensureDistributionAndOrdering(
newPlan :: Nil,
requiredDistribution.get :: Nil,
Seq(Nil),
shuffleOrigin)
assert(finalPlan.size == 1)
finalPlan.head
} else {
newPlan
}
}
}
Loading