Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class Analyzer(
lazy val batches: Seq[Batch] = Seq(
Batch("Hints", fixedPoint,
new ResolveHints.ResolveJoinStrategyHints(conf),
ResolveHints.ResolveCoalesceHints),
new ResolveHints.ResolveCoalesceHints(conf)),
Batch("Simple Sanity Check", Once,
LookupFunctions),
Batch("Substitution", fixedPoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
Expand Down Expand Up @@ -137,31 +137,101 @@ object ResolveHints {
}

/**
* COALESCE Hint accepts name "COALESCE" and "REPARTITION".
* Its parameter includes a partition number.
* COALESCE Hint accepts names "COALESCE", "REPARTITION", and "REPARTITION_BY_RANGE".
*/
object ResolveCoalesceHints extends Rule[LogicalPlan] {
private val COALESCE_HINT_NAMES = Set("COALESCE", "REPARTITION")
class ResolveCoalesceHints(conf: SQLConf) extends Rule[LogicalPlan] {

/**
* This function handles hints for "COALESCE" and "REPARTITION".
* The "COALESCE" hint only has a partition number as a parameter. The "REPARTITION" hint
* has a partition number, columns, or both of them as parameters.
*/
private def createRepartition(
shuffle: Boolean, hint: UnresolvedHint): LogicalPlan = {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
if (sortOrders.nonEmpty) throw new IllegalArgumentException(
Copy link
Contributor

@hvanhovell hvanhovell Feb 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style, please put this inside curly braces and on a new line.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ulysses-you Could you do a follow-up for the two comments from @hvanhovell?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it might be better to just handle this later when we happen to touch this code, given that we don't usually make follow-ups for minor style issues.

s"""Invalid partitionExprs specified: $sortOrders
|For range partitioning use REPARTITION_BY_RANGE instead.
""".stripMargin)
val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check breaks the old API, in Spark 2.4 it is possible to use an expression here. I think we need to back this out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK 2.4 only supports something like REPARTITION(5); the parameters here mean anything after the partition number parameter, e.g. REPARTITION(5, para1, para2, ...).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I think so, too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK, you are right. Nevertheless, I think we should support expressions for REPARTITION here.

if (invalidParams.nonEmpty) {
throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
s"${invalidParams.mkString(", ")} found")
}
RepartitionByExpression(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we then consistently throw an exception like Dataset.repartition?

    val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder])
    if (sortOrders.nonEmpty) throw new IllegalArgumentException(
      s"""Invalid partitionExprs specified: $sortOrders
         |For range partitioning use repartitionByRange(...) instead.
       """.stripMargin)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, add an IllegalArgumentException check.

partitionExprs.map(_.asInstanceOf[Expression]), hint.child, numPartitions)
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case h: UnresolvedHint if COALESCE_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
val hintName = h.name.toUpperCase(Locale.ROOT)
val shuffle = hintName match {
case "REPARTITION" => true
case "COALESCE" => false
hint.parameters match {
case Seq(IntegerLiteral(numPartitions)) =>
Repartition(numPartitions, shuffle, hint.child)
case Seq(numPartitions: Int) =>
Repartition(numPartitions, shuffle, hint.child)
// The "COALESCE" hint (shuffle = false) must have a partition number only
case _ if !shuffle =>
throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
createRepartitionByExpression(numPartitions, param.tail)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, it's REPARTITION(), but why does it create a range partition in this case? Do you intend to support range partitioning via something like REPARTITION(...)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I tried to keep consistent with Dataset.repartition(). The latter has two methods repartition(numPartitions: Int) and repartition(numPartitions: Int, partitionExprs: Column*)

case param @ Seq(numPartitions: Int, _*) if shuffle =>
createRepartitionByExpression(numPartitions, param.tail)
case param @ Seq(_*) if shuffle =>
createRepartitionByExpression(conf.numShufflePartitions, param)
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the case, SELECT /*+ REPARTITION(a) */ * FROM t?

/**
* This function handles hints for "REPARTITION_BY_RANGE".
* The "REPARTITION_BY_RANGE" hint must have column names and a partition number is optional.
*/
private def createRepartitionByRange(hint: UnresolvedHint): RepartitionByExpression = {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

createRepartitionByExpression seems duplicated. Can we just make one private function to share?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The duplication is intentional, to keep the methods for repartition(...) and repartitionByRange(...) clearly separated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, both inner functions and duplication are discouraged, but okay.

numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above.

if (invalidParams.nonEmpty) {
throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
s"${invalidParams.mkString(", ")} found")
}
val numPartitions = h.parameters match {
case Seq(IntegerLiteral(numPartitions)) =>
numPartitions
case Seq(numPartitions: Int) =>
numPartitions
case _ =>
throw new AnalysisException(s"$hintName Hint expects a partition number as parameter")
val sortOrder = partitionExprs.map {
case expr: SortOrder => expr
case expr: Expression => SortOrder(expr, Ascending)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC this hint cannot accept all Expressions...

}
RepartitionByExpression(sortOrder, hint.child, numPartitions)
}

hint.parameters match {
case param @ Seq(IntegerLiteral(numPartitions), _*) =>
createRepartitionByExpression(numPartitions, param.tail)
case param @ Seq(numPartitions: Int, _*) =>
createRepartitionByExpression(numPartitions, param.tail)
case param @ Seq(_*) =>
createRepartitionByExpression(conf.numShufflePartitions, param)
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
case "REPARTITION" =>
createRepartition(shuffle = true, hint)
case "COALESCE" =>
createRepartition(shuffle = false, hint)
case "REPARTITION_BY_RANGE" =>
createRepartitionByRange(hint)
case _ => plan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return `hint` here? This will cause a stack overflow once the hint is not the root node.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes. I will make a followup. Thanks for catching this.

}
Repartition(numPartitions, shuffle, h.child)
}
}

/**
 * Holds the set of hint names that the coalesce/repartition hint handling recognizes.
 */
object ResolveCoalesceHints {
  /** Upper-case names of the supported hints. */
  val COALESCE_HINT_NAMES: Set[String] =
    Set(
      "COALESCE",
      "REPARTITION",
      "REPARTITION_BY_RANGE")
}

/**
* Removes all the hints, used to remove invalid hints provided by the user.
* This must be executed after all the other hint rules are executed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ import org.apache.log4j.spi.LoggingEvent

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Literal, SortOrder}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

class ResolveHintsSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
Expand Down Expand Up @@ -150,24 +151,86 @@ class ResolveHintsSuite extends AnalysisTest {
UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")),
Repartition(numPartitions = 200, shuffle = true, child = testRelation))

val errMsgCoal = "COALESCE Hint expects a partition number as parameter"
val errMsg = "COALESCE Hint expects a partition number as a parameter"

assertAnalysisError(
UnresolvedHint("COALESCE", Seq.empty, table("TaBlE")),
Seq(errMsgCoal))
Seq(errMsg))
assertAnalysisError(
UnresolvedHint("COALESCE", Seq(Literal(10), Literal(false)), table("TaBlE")),
Seq(errMsgCoal))
Seq(errMsg))
assertAnalysisError(
UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")),
Seq(errMsgCoal))
Seq(errMsg))

val errMsgRepa = "REPARTITION Hint expects a partition number as parameter"
assertAnalysisError(
checkAnalysis(
UnresolvedHint("RePartition", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10))

checkAnalysis(
UnresolvedHint("REPARTITION", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10))

checkAnalysis(
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
Seq(errMsgRepa))
RepartitionByExpression(
Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))

// A SortOrder passed to the REPARTITION hint is expected to be rejected with an
// IllegalArgumentException that points the user at REPARTITION_BY_RANGE.
// The expected plan passed to checkAnalysis is not relevant when the call throws.
val e = intercept[IllegalArgumentException] {
  checkAnalysis(
    UnresolvedHint("REPARTITION",
      Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
      table("TaBlE")),
    RepartitionByExpression(
      Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)), testRelation, 10)
  )
}
// Assert on the message: a bare `contains` call only yields a Boolean that is discarded,
// so the message would otherwise never be verified.
assert(e.getMessage.contains("For range partitioning use REPARTITION_BY_RANGE instead"))

checkAnalysis(
UnresolvedHint(
"REPARTITION_BY_RANGE", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)), testRelation, 10))

checkAnalysis(
UnresolvedHint(
"REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
testRelation, conf.numShufflePartitions))

val errMsg2 = "REPARTITION Hint parameter should include columns, but"

assertAnalysisError(
UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")),
Seq(errMsgRepa))
Seq(errMsg2))

assertAnalysisError(
UnresolvedHint("REPARTITION",
Seq(Literal(1.0), AttributeReference("a", IntegerType)()),
table("TaBlE")),
Seq(errMsg2))

val errMsg3 = "REPARTITION_BY_RANGE Hint parameter should include columns, but"

assertAnalysisError(
UnresolvedHint("REPARTITION_BY_RANGE",
Seq(Literal(1.0), AttributeReference("a", IntegerType)()),
table("TaBlE")),
Seq(errMsg3))

assertAnalysisError(
UnresolvedHint("REPARTITION_BY_RANGE",
Seq(Literal(10), Literal(10)),
table("TaBlE")),
Seq(errMsg3))

assertAnalysisError(
UnresolvedHint("REPARTITION_BY_RANGE",
Seq(Literal(10), Literal(10), UnresolvedAttribute("a")),
table("TaBlE")),
Seq(errMsg3))
}

test("log warnings for invalid hints") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,52 @@ class PlanParserSuite extends AnalysisTest {
table("t").select(star()))))

intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input")

comparePlans(
parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")),
table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star()))))

comparePlans(
parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star())))))

comparePlans(
parsePlan(
"""
|SELECT
|/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
|* FROM t
""".stripMargin),
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")),
table("t").select(star()))))))

comparePlans(
parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"),
UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")),
table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"),
UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))
}

test("SPARK-20854: select hint syntax with expressions") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,17 @@ class DataFrameHintSuite extends AnalysisTest with SharedSparkSession {
check(
df.hint("REPARTITION", 100),
UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan))

check(
df.hint("REPARTITION", 10, $"id".expr),
UnresolvedHint("REPARTITION", Seq(10, $"id".expr), df.logicalPlan))

check(
df.hint("REPARTITION_BY_RANGE", $"id".expr),
UnresolvedHint("REPARTITION_BY_RANGE", Seq($"id".expr), df.logicalPlan))

check(
df.hint("REPARTITION_BY_RANGE", 10, $"id".expr),
UnresolvedHint("REPARTITION_BY_RANGE", Seq(10, $"id".expr), df.logicalPlan))
}
}