diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 21bf926af50d..d92987887b70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -186,7 +186,7 @@ class Analyzer(
   lazy val batches: Seq[Batch] = Seq(
     Batch("Hints", fixedPoint,
       new ResolveHints.ResolveJoinStrategyHints(conf),
-      ResolveHints.ResolveCoalesceHints),
+      new ResolveHints.ResolveCoalesceHints(conf)),
     Batch("Simple Sanity Check", Once,
       LookupFunctions),
     Batch("Substitution", fixedPoint,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 3a9c4b7392e3..d904ba3aca5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
@@ -137,31 +137,102 @@ object ResolveHints {
   }
 
   /**
-   * COALESCE Hint accepts name "COALESCE" and "REPARTITION".
-   * Its parameter includes a partition number.
+   * COALESCE Hint accepts names "COALESCE", "REPARTITION", and "REPARTITION_BY_RANGE".
    */
-  object ResolveCoalesceHints extends Rule[LogicalPlan] {
-    private val COALESCE_HINT_NAMES = Set("COALESCE", "REPARTITION")
+  class ResolveCoalesceHints(conf: SQLConf) extends Rule[LogicalPlan] {
+
+    /**
+     * This function handles hints for "COALESCE" and "REPARTITION".
+     * The "COALESCE" hint only has a partition number as a parameter. The "REPARTITION" hint
+     * has a partition number, columns, or both of them as parameters.
+     */
+    private def createRepartition(
+        shuffle: Boolean, hint: UnresolvedHint): LogicalPlan = {
+      val hintName = hint.name.toUpperCase(Locale.ROOT)
+
+      def createRepartitionByExpression(
+          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+        val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
+        if (sortOrders.nonEmpty) throw new IllegalArgumentException(
+          s"""Invalid partitionExprs specified: $sortOrders
+             |For range partitioning use REPARTITION_BY_RANGE instead.
+           """.stripMargin)
+        val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
+        if (invalidParams.nonEmpty) {
+          throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
+            s"${invalidParams.mkString(", ")} found")
+        }
+        RepartitionByExpression(
+          partitionExprs.map(_.asInstanceOf[Expression]), hint.child, numPartitions)
+      }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
-      case h: UnresolvedHint if COALESCE_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
-        val hintName = h.name.toUpperCase(Locale.ROOT)
-        val shuffle = hintName match {
-          case "REPARTITION" => true
-          case "COALESCE" => false
+      hint.parameters match {
+        case Seq(IntegerLiteral(numPartitions)) =>
+          Repartition(numPartitions, shuffle, hint.child)
+        case Seq(numPartitions: Int) =>
+          Repartition(numPartitions, shuffle, hint.child)
+        // The "COALESCE" hint (shuffle = false) must have a partition number only
+        case _ if !shuffle =>
+          throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")
+
+        case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
+          createRepartitionByExpression(numPartitions, param.tail)
+        case param @ Seq(numPartitions: Int, _*) if shuffle =>
+          createRepartitionByExpression(numPartitions, param.tail)
+        case param @ Seq(_*) if shuffle =>
+          createRepartitionByExpression(conf.numShufflePartitions, param)
+      }
+    }
+
+    /**
+     * This function handles hints for "REPARTITION_BY_RANGE".
+     * The "REPARTITION_BY_RANGE" hint must have column names and a partition number is optional.
+     */
+    private def createRepartitionByRange(hint: UnresolvedHint): RepartitionByExpression = {
+      val hintName = hint.name.toUpperCase(Locale.ROOT)
+
+      def createRepartitionByExpression(
+          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+        val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
+        if (invalidParams.nonEmpty) {
+          throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
+            s"${invalidParams.mkString(", ")} found")
         }
-        val numPartitions = h.parameters match {
-          case Seq(IntegerLiteral(numPartitions)) =>
-            numPartitions
-          case Seq(numPartitions: Int) =>
-            numPartitions
-          case _ =>
-            throw new AnalysisException(s"$hintName Hint expects a partition number as parameter")
+        val sortOrder = partitionExprs.map {
+          case expr: SortOrder => expr
+          case expr: Expression => SortOrder(expr, Ascending)
+        }
+        RepartitionByExpression(sortOrder, hint.child, numPartitions)
+      }
+
+      hint.parameters match {
+        case param @ Seq(IntegerLiteral(numPartitions), _*) =>
+          createRepartitionByExpression(numPartitions, param.tail)
+        case param @ Seq(numPartitions: Int, _*) =>
+          createRepartitionByExpression(numPartitions, param.tail)
+        case param @ Seq(_*) =>
+          createRepartitionByExpression(conf.numShufflePartitions, param)
+      }
+    }
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+      case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
+          case "REPARTITION" =>
+            createRepartition(shuffle = true, hint)
+          case "COALESCE" =>
+            createRepartition(shuffle = false, hint)
+          case "REPARTITION_BY_RANGE" =>
+            createRepartitionByRange(hint)
+          // Return the hint node itself (not the whole plan) so unrelated hints stay in place.
+          case _ => hint
         }
-        Repartition(numPartitions, shuffle, h.child)
     }
   }
 
+  object ResolveCoalesceHints {
+    val COALESCE_HINT_NAMES: Set[String] = Set("COALESCE", "REPARTITION", "REPARTITION_BY_RANGE")
+  }
+
   /**
    * Removes all the hints, used to remove invalid hints provided by the user.
    * This must be executed after all the other hint rules are executed.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 474e58a335e7..cddcddd51e38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -24,10 +24,11 @@ import org.apache.log4j.spi.LoggingEvent
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
 
 class ResolveHintsSuite extends AnalysisTest {
   import org.apache.spark.sql.catalyst.analysis.TestRelations._
@@ -150,24 +151,86 @@ class ResolveHintsSuite extends AnalysisTest {
       UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")),
       Repartition(numPartitions = 200, shuffle = true, child = testRelation))
 
-    val errMsgCoal = "COALESCE Hint expects a partition number as parameter"
+    val errMsg = "COALESCE Hint expects a partition number as a parameter"
+
     assertAnalysisError(
       UnresolvedHint("COALESCE", Seq.empty, table("TaBlE")),
-      Seq(errMsgCoal))
+      Seq(errMsg))
     assertAnalysisError(
       UnresolvedHint("COALESCE", Seq(Literal(10), Literal(false)), table("TaBlE")),
-      Seq(errMsgCoal))
+      Seq(errMsg))
     assertAnalysisError(
       UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")),
-      Seq(errMsgCoal))
+      Seq(errMsg))
 
-    val errMsgRepa = "REPARTITION Hint expects a partition number as parameter"
-    assertAnalysisError(
+    checkAnalysis(
+      UnresolvedHint("RePartition", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
+      RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10))
+
+    checkAnalysis(
+      UnresolvedHint("REPARTITION", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
+      RepartitionByExpression(Seq(AttributeReference("a", IntegerType)()), testRelation, 10))
+
+    checkAnalysis(
       UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
-      Seq(errMsgRepa))
+      RepartitionByExpression(
+        Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+
+    val e = intercept[IllegalArgumentException] {
+      checkAnalysis(
+        UnresolvedHint("REPARTITION",
+          Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
+          table("TaBlE")),
+        RepartitionByExpression(
+          Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)), testRelation, 10)
+      )
+    }
+    assert(e.getMessage.contains("For range partitioning use REPARTITION_BY_RANGE instead"))
+
+    checkAnalysis(
+      UnresolvedHint(
+        "REPARTITION_BY_RANGE", Seq(Literal(10), UnresolvedAttribute("a")), table("TaBlE")),
+      RepartitionByExpression(
+        Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)), testRelation, 10))
+
+    checkAnalysis(
+      UnresolvedHint(
+        "REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
+      RepartitionByExpression(
+        Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
+        testRelation, conf.numShufflePartitions))
+
+    val errMsg2 = "REPARTITION Hint parameter should include columns, but"
+
     assertAnalysisError(
       UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")),
-      Seq(errMsgRepa))
+      Seq(errMsg2))
+
+    assertAnalysisError(
+      UnresolvedHint("REPARTITION",
+        Seq(Literal(1.0), AttributeReference("a", IntegerType)()),
+        table("TaBlE")),
+      Seq(errMsg2))
+
+    val errMsg3 = "REPARTITION_BY_RANGE Hint parameter should include columns, but"
+
+    assertAnalysisError(
+      UnresolvedHint("REPARTITION_BY_RANGE",
+        Seq(Literal(1.0), AttributeReference("a", IntegerType)()),
+        table("TaBlE")),
+      Seq(errMsg3))
+
+    assertAnalysisError(
+      UnresolvedHint("REPARTITION_BY_RANGE",
+        Seq(Literal(10), Literal(10)),
+        table("TaBlE")),
+      Seq(errMsg3))
+
+    assertAnalysisError(
+      UnresolvedHint("REPARTITION_BY_RANGE",
+        Seq(Literal(10), Literal(10), UnresolvedAttribute("a")),
+        table("TaBlE")),
+      Seq(errMsg3))
   }
 
   test("log warnings for invalid hints") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 8b940a7aa2c3..875096f61524 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -624,6 +624,52 @@ class PlanParserSuite extends AnalysisTest {
             table("t").select(star()))))
 
     intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input")
+
+    comparePlans(
+      parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"),
+      UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")),
+        table("t").select(star())))
+
+    comparePlans(
+      parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"),
+      UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
+        table("t").select(star())))
+
+    comparePlans(
+      parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"),
+      UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
+        UnresolvedHint("COALESCE", Seq(Literal(50)),
+          table("t").select(star()))))
+
+    comparePlans(
+      parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"),
+      UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
+        UnresolvedHint("BROADCASTJOIN", Seq($"u"),
+          UnresolvedHint("COALESCE", Seq(Literal(50)),
+            table("t").select(star())))))
+
+    comparePlans(
+      parsePlan(
+        """
+          |SELECT
+          |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
+          |* FROM t
+        """.stripMargin),
+      UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
+        UnresolvedHint("BROADCASTJOIN", Seq($"u"),
+          UnresolvedHint("COALESCE", Seq(Literal(50)),
+            UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")),
+              table("t").select(star()))))))
+
+    comparePlans(
+      parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"),
+      UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")),
+        table("t").select(star())))
+
+    comparePlans(
+      parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"),
+      UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")),
+        table("t").select(star())))
   }
 
   test("SPARK-20854: select hint syntax with expressions") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
index b33c26a0b75a..37dc8f1bcc7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
@@ -68,5 +68,17 @@ class DataFrameHintSuite extends AnalysisTest with SharedSparkSession {
     check(
       df.hint("REPARTITION", 100),
       UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan))
+
+    check(
+      df.hint("REPARTITION", 10, $"id".expr),
+      UnresolvedHint("REPARTITION", Seq(10, $"id".expr), df.logicalPlan))
+
+    check(
+      df.hint("REPARTITION_BY_RANGE", $"id".expr),
+      UnresolvedHint("REPARTITION_BY_RANGE", Seq($"id".expr), df.logicalPlan))
+
+    check(
+      df.hint("REPARTITION_BY_RANGE", 10, $"id".expr),
+      UnresolvedHint("REPARTITION_BY_RANGE", Seq(10, $"id".expr), df.logicalPlan))
   }
 }