diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3c531323397e4..c062e4e84bcdd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -373,8 +373,8 @@ package object dsl {
       def repartition(num: Integer): LogicalPlan =
         Repartition(num, shuffle = true, logicalPlan)
 
-      def distribute(exprs: Expression*)(n: Int = -1): LogicalPlan =
-        RepartitionByExpression(exprs, logicalPlan, numPartitions = if (n < 0) None else Some(n))
+      def distribute(exprs: Expression*)(n: Int): LogicalPlan =
+        RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
       def analyze: LogicalPlan =
         EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0c13e3e93a42c..af846a09a8d89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -578,7 +578,7 @@ object CollapseRepartition extends Rule[LogicalPlan] {
       RepartitionByExpression(exprs, child, numPartitions)
     // Case 3
     case Repartition(numPartitions, _, r: RepartitionByExpression) =>
-      r.copy(numPartitions = Some(numPartitions))
+      r.copy(numPartitions = numPartitions)
     // Case 3
     case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
       RepartitionByExpression(exprs, child, numPartitions)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 08a6dd136b857..926a37b363f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -242,20 +242,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       Sort(sort.asScala.map(visitSortItem), global = false, query)
     } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
       // DISTRIBUTE BY ...
-      RepartitionByExpression(expressionList(distributeBy), query)
+      withRepartitionByExpression(ctx, expressionList(distributeBy), query)
     } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
       // SORT BY ... DISTRIBUTE BY ...
       Sort(
         sort.asScala.map(visitSortItem),
         global = false,
-        RepartitionByExpression(expressionList(distributeBy), query))
+        withRepartitionByExpression(ctx, expressionList(distributeBy), query))
     } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
       // CLUSTER BY ...
       val expressions = expressionList(clusterBy)
       Sort(
         expressions.map(SortOrder(_, Ascending)),
         global = false,
-        RepartitionByExpression(expressions, query))
+        withRepartitionByExpression(ctx, expressions, query))
     } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
       // [EMPTY]
       query
@@ -273,6 +273,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     }
   }
 
+  /**
+   * Create a clause for DISTRIBUTE BY.
+   */
+  protected def withRepartitionByExpression(
+      ctx: QueryOrganizationContext,
+      expressions: Seq[Expression],
+      query: LogicalPlan): LogicalPlan = {
+    throw new ParseException("DISTRIBUTE BY is not supported", ctx)
+  }
+
   /**
    * Create a logical plan using a query specification.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index af57632516790..d17d12cd83091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -844,18 +844,13 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
  * information about the number of partitions during execution. Used when a specific ordering or
  * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
  * `coalesce` and `repartition`.
- * If `numPartitions` is not specified, the number of partitions will be the number set by
- * `spark.sql.shuffle.partitions`.
  */
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    numPartitions: Option[Int] = None) extends UnaryNode {
+    numPartitions: Int) extends UnaryNode {
 
-  numPartitions match {
-    case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.")
-    case None => // Ok
-  }
+  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
 
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 786e0f49b4b25..01737e0a17341 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -21,11 +21,12 @@ import java.util.TimeZone
 
 import org.scalatest.ShouldMatchers
 
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -192,12 +193,13 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
   }
 
   test("pull out nondeterministic expressions from RepartitionByExpression") {
-    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
+    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation, numPartitions = 10)
     val projected = Alias(Rand(33), "_nondeterministic")()
     val expected =
       Project(testRelation.output,
         RepartitionByExpression(Seq(projected.toAttribute),
-          Project(testRelation.output :+ projected, testRelation)))
+          Project(testRelation.output :+ projected, testRelation),
+          numPartitions = 10))
     checkAnalysis(plan, expected)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 2c1425242620e..67d5d2202b680 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -152,10 +152,7 @@ class PlanParserSuite extends PlanTest {
     val orderSortDistrClusterClauses = Seq(
       ("", basePlan),
       (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
-      (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
-      (" distribute by a, b", basePlan.distribute('a, 'b)()),
-      (" distribute by a sort by b", basePlan.distribute('a)().sortBy('b.asc)),
-      (" cluster by a, b", basePlan.distribute('a, 'b)().sortBy('a.asc, 'b.asc))
+      (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc))
     )
 
     orderSortDistrClusterClauses.foreach {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 38a24cc8ed8c2..1ebc53d2bb84e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2410,7 +2410,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
+    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
   }
 
   /**
@@ -2425,7 +2425,8 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
+    RepartitionByExpression(
+      partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index d50800235264f..1340aebc1ddd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -22,16 +22,17 @@ import scala.collection.JavaConverters._
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.tree.TerminalNode
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser._
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.{CreateTable, _}
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Concrete parser for Spark SQL statements.
  */
@@ -1441,4 +1442,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
       reader, writer,
       schemaLess)
   }
+
+  /**
+   * Create a clause for DISTRIBUTE BY.
+   */
+  override protected def withRepartitionByExpression(
+      ctx: QueryOrganizationContext,
+      expressions: Seq[Expression],
+      query: LogicalPlan): LogicalPlan = {
+    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 557181ebd9590..0e3d5595df94b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -332,8 +332,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
-    def numPartitions: Int = self.numPartitions
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
 
@@ -414,9 +412,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
-      case logical.RepartitionByExpression(expressions, child, nPartitions) =>
+      case logical.RepartitionByExpression(expressions, child, numPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
-          expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
+          expressions, numPartitions), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 15e490fb30a27..bb6c486e880a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.CreateTable
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
@@ -36,7 +38,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
  */
 class SparkSqlParserSuite extends PlanTest {
 
-  private lazy val parser = new SparkSqlParser(new SQLConf)
+  val newConf = new SQLConf
+  private lazy val parser = new SparkSqlParser(newConf)
 
   /**
    * Normalizes plans:
@@ -251,4 +254,29 @@ class SparkSqlParserSuite extends PlanTest {
     assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
       AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
   }
+
+  test("query organization") {
+    // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
+    val baseSql = "select * from t"
+    val basePlan =
+      Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("t")))
+
+    assertEqual(s"$baseSql distribute by a, b",
+      RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
+        basePlan,
+        numPartitions = newConf.numShufflePartitions))
+    assertEqual(s"$baseSql distribute by a sort by b",
+      Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+        global = false,
+        RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
+          basePlan,
+          numPartitions = newConf.numShufflePartitions)))
+    assertEqual(s"$baseSql cluster by a, b",
+      Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
+          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+        global = false,
+        RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
+          basePlan,
+          numPartitions = newConf.numShufflePartitions)))
+  }
 }
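A minimal usage sketch follows (not part of the patch; it assumes a local SparkSession and hypothetical toy data) illustrating the contract this diff pins down: repartitioning by expressions always carries a concrete partition count, resolved from spark.sql.shuffle.partitions when the caller does not supply one, instead of leaving numPartitions as None in the logical plan.

import org.apache.spark.sql.SparkSession

object RepartitionByExpressionSketch {
  def main(args: Array[String]): Unit = {
    // Assumed settings: a local session with spark.sql.shuffle.partitions = 7.
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.shuffle.partitions", "7")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical toy data; any DataFrame works.
    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "value")

    // Explicit count: the plan holds RepartitionByExpression(..., numPartitions = 3).
    assert(df.repartition(3, $"key").rdd.getNumPartitions == 3)

    // No explicit count: the count is now resolved eagerly to
    // spark.sql.shuffle.partitions (7 here) rather than carried as None.
    assert(df.repartition($"key").rdd.getNumPartitions == 7)

    spark.stop()
  }
}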