Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
SubstituteUnresolvedOrdinals,
EliminateLazyExpression),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Expand Down Expand Up @@ -1975,24 +1974,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
withPosition(ordinal) {
if (index > 0 && index <= aggs.size) {
val ordinalExpr = aggs(index - 1)

if (ordinalExpr.exists(_.isInstanceOf[AggregateExpression])) {
throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError(
index, ordinalExpr)
} else {
trimAliases(ordinalExpr) match {
// HACK ALERT: If the ordinal expression is also an integer literal, don't use it
// but still keep the ordinal literal. The reason is we may repeatedly
// analyze the plan. Using a different integer literal may lead to
// a repeat GROUP BY ordinal resolution which is wrong. GROUP BY
// constant is meaningless so whatever value does not matter here.
// TODO: (SPARK-45932) GROUP BY ordinal should pull out grouping expressions to
// a Project, then the resolved ordinal expression is always
// `AttributeReference`.
case Literal(_: Int, IntegerType) =>
Literal(index)
case _ => ordinalExpr
}
}

ordinalExpr
} else {
throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, IntegerLiteral, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE}
Expand Down Expand Up @@ -129,27 +129,9 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
groupExprs: Seq[Expression]): Seq[Expression] = {
assert(selectList.forall(_.resolved))
if (isGroupByAll(groupExprs)) {
val expandedGroupExprs = expandGroupByAll(selectList)
if (expandedGroupExprs.isEmpty) {
// Don't replace the ALL when we fail to infer the grouping columns. We will eventually
// tell the user in checkAnalysis that we cannot resolve the all in group by.
groupExprs
} else {
// This is a valid GROUP BY ALL aggregate.
expandedGroupExprs.get.zipWithIndex.map { case (expr, index) =>
trimAliases(expr) match {
// HACK ALERT: If the expanded grouping expression is an integer literal, don't use it
// but use an integer literal of the index. The reason is we may repeatedly
// analyze the plan, and the original integer literal may cause failures
// with a later GROUP BY ordinal resolution. GROUP BY constant is
// meaningless so whatever value does not matter here.
case IntegerLiteral(_) =>
// GROUP BY ordinal uses 1-based index.
Literal(index + 1)
case _ => expr
}
}
}
// Don't replace the ALL when we fail to infer the grouping columns. We will eventually tell
// the user in checkAnalysis that we cannot resolve the all in group by.
expandGroupByAll(selectList).getOrElse(groupExprs)
} else {
groupExprs
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.language.implicitConversions

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -64,7 +65,7 @@ import org.apache.spark.unsafe.types.UTF8String
* LocalRelation [key#2,value#3], []
* }}}
*/
package object dsl {
package object dsl extends SQLConfHelper {
trait ImplicitOperators {
def expr: Expression

Expand Down Expand Up @@ -446,11 +447,16 @@ package object dsl {
def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan)

/**
 * Builds an [[Aggregate]] over `logicalPlan` from the given grouping and aggregate
 * expressions.
 *
 * Top-level integer literals among `groupingExprs` are turned into [[UnresolvedOrdinal]]s
 * when `conf.groupByOrdinal` is enabled; every other grouping expression is passed through
 * untouched. Aggregate expressions that are not already [[NamedExpression]]s are wrapped in
 * [[UnresolvedAlias]].
 *
 * @param groupingExprs expressions to group by
 * @param aggregateExprs expressions forming the aggregate output list
 * @return the resulting [[Aggregate]] plan
 */
def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = {
  // Replace top-level integer literals with ordinals, if `groupByOrdinal` is enabled.
  val groupingExpressionsWithOrdinals = groupingExprs.map {
    case Literal(value: Int, IntegerType) if conf.groupByOrdinal => UnresolvedOrdinal(value)
    case other => other
  }
  val aliasedExprs = aggregateExprs.map {
    case ne: NamedExpression => ne
    case e => UnresolvedAlias(e)
  }
  // Only the ordinal-aware grouping list is used; the previously constructed
  // `Aggregate(groupingExprs, ...)` was discarded and has been removed.
  Aggregate(groupingExpressionsWithOrdinals, aliasedExprs, logicalPlan)
}

def having(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.parser

import java.util.Locale
import java.util.{List, Locale}
import java.util.concurrent.TimeUnit

import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Set}
Expand Down Expand Up @@ -1286,17 +1286,17 @@ class AstBuilder extends DataTypeAstBuilder
val withOrder = if (
!order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.orderByClause
Sort(order.asScala.map(visitSortItem).toSeq, global = true, query)
Sort(order.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = true, query)
} else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.sortByClause
Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query)
Sort(sort.asScala.map(visitSortItemAndReplaceOrdinals).toSeq, global = false, query)
} else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.distributeByClause
withRepartitionByExpression(ctx, expressionList(distributeBy), query)
} else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
clause = PipeOperators.sortByDistributeByClause
Sort(
sort.asScala.map(visitSortItem).toSeq,
sort.asScala.map(visitSortItemAndReplaceOrdinals).toSeq,
global = false,
withRepartitionByExpression(ctx, expressionList(distributeBy), query))
} else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
Expand Down Expand Up @@ -1825,24 +1825,27 @@ class AstBuilder extends DataTypeAstBuilder
}
visitNamedExpression(n)
}.toSeq
val groupByExpressionsWithOrdinals =
replaceOrdinalsInGroupingExpressions(groupByExpressions)
if (ctx.GROUPING != null) {
// GROUP BY ... GROUPING SETS (...)
// `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping
// expressions that do not exist in GROUPING SETS (...), and the value is always null.
// For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output
// of column `c` is always null.
val groupingSets =
ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq)
Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)),
selectExpressions, query)
val groupingSetsWithOrdinals = visitGroupingSetAndReplaceOrdinals(ctx.groupingSet)
Aggregate(
Seq(GroupingSets(groupingSetsWithOrdinals, groupByExpressionsWithOrdinals)),
selectExpressions, query
)
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
val mappedGroupByExpressions = if (ctx.CUBE != null) {
Seq(Cube(groupByExpressions.map(Seq(_))))
Seq(Cube(groupByExpressionsWithOrdinals.map(Seq(_))))
} else if (ctx.ROLLUP != null) {
Seq(Rollup(groupByExpressions.map(Seq(_))))
Seq(Rollup(groupByExpressionsWithOrdinals.map(Seq(_))))
} else {
groupByExpressions
groupByExpressionsWithOrdinals
}
Aggregate(mappedGroupByExpressions, selectExpressions, query)
}
Expand All @@ -1856,16 +1859,20 @@ class AstBuilder extends DataTypeAstBuilder
} else {
expression(groupByExpr.expression)
}
})
Aggregate(groupByExpressions.toSeq, selectExpressions, query)
}).toSeq
Aggregate(
groupingExpressions = replaceOrdinalsInGroupingExpressions(groupByExpressions),
aggregateExpressions = selectExpressions,
child = query
)
}
}

override def visitGroupingAnalytics(
groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = {
val groupingSets = groupingAnalytics.groupingSet.asScala
.map(_.expression.asScala.map(e => expression(e)).toSeq)
if (groupingAnalytics.CUBE != null) {
val baseGroupingSet = if (groupingAnalytics.CUBE != null) {
// CUBE(A, B, (A, B), ()) is not supported.
if (groupingSets.exists(_.isEmpty)) {
throw QueryParsingErrors.invalidGroupingSetError("CUBE", groupingAnalytics)
Expand All @@ -1889,6 +1896,9 @@ class AstBuilder extends DataTypeAstBuilder
}
GroupingSets(groupingSets.toSeq)
}
baseGroupingSet.withNewChildren(
newChildren = replaceOrdinalsInGroupingExpressions(baseGroupingSet.children)
).asInstanceOf[BaseGroupingSets]
}

/**
Expand Down Expand Up @@ -6532,12 +6542,12 @@ class AstBuilder extends DataTypeAstBuilder
case n: NamedExpression =>
newGroupingExpressions += n
newAggregateExpressions += n
// If the grouping expression is an integer literal, create [[UnresolvedOrdinal]] and
// [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final grouping
// and aggregate expressions, respectively. This will let the
// If the grouping expression is an [[UnresolvedOrdinal]], replace the ordinal value and
// create [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final
// grouping and aggregate expressions, respectively. This will let the
// [[ResolveOrdinalInOrderByAndGroupBy]] rule detect the ordinal in the aggregate list
// and replace it with the corresponding attribute from the child operator.
case Literal(v: Int, IntegerType) if conf.groupByOrdinal =>
case UnresolvedOrdinal(v: Int) =>
newGroupingExpressions += UnresolvedOrdinal(newAggregateExpressions.length + 1)
newAggregateExpressions += UnresolvedAlias(UnresolvedPipeAggregateOrdinal(v), None)
case e: Expression =>
Expand All @@ -6558,6 +6568,57 @@ class AstBuilder extends DataTypeAstBuilder
}
}

/**
 * Builds the [[SortOrder]] for a [[SortItemContext]] via `visitSortItem`, then swaps its child
 * for the result of `replaceIntegerLiteralWithOrdinal`. A top-level integer [[Literal]] sort
 * key therefore becomes an [[UnresolvedOrdinal]] when `conf.orderByOrdinal` is enabled; any
 * other sort key is kept as is.
 *
 * The node is rebuilt through `withNewChildren` with a single child and cast back to
 * [[SortOrder]].
 */
private def visitSortItemAndReplaceOrdinals(sortItemContext: SortItemContext): SortOrder = {
  val sortOrder = visitSortItem(sortItemContext)
  val sortKey = replaceIntegerLiteralWithOrdinal(
    expression = sortOrder.child,
    canReplaceWithOrdinal = conf.orderByOrdinal
  )
  sortOrder.withNewChildren(newChildren = Seq(sortKey)).asInstanceOf[SortOrder]
}

/**
 * Runs every grouping expression through `replaceIntegerLiteralWithOrdinal`, so that top-level
 * integer [[Literal]]s become [[UnresolvedOrdinal]]s when `conf.groupByOrdinal` is enabled.
 * Expressions that are not top-level integer literals are returned unchanged, in the same
 * order.
 */
private def replaceOrdinalsInGroupingExpressions(
    groupingExpressions: Seq[Expression]): Seq[Expression] = {
  for (groupingExpression <- groupingExpressions) yield {
    replaceIntegerLiteralWithOrdinal(
      expression = groupingExpression,
      canReplaceWithOrdinal = conf.groupByOrdinal
    )
  }
}

/**
 * Converts each [[GroupingSetContext]] into a sequence of expressions: every expression
 * context is visited with `expression(...)` and the result is passed through
 * `replaceIntegerLiteralWithOrdinal`, which turns top-level integer [[Literal]]s into
 * [[UnresolvedOrdinal]]s when `conf.groupByOrdinal` is enabled.
 *
 * NOTE(review): `List` here is presumably `java.util.List`, given the `.asScala` conversion
 * -- confirm against the file's imports.
 */
private def visitGroupingSetAndReplaceOrdinals(
    groupingSet: List[GroupingSetContext]): Seq[Seq[Expression]] = {
  groupingSet.asScala.map { groupingSetContext =>
    groupingSetContext.expression.asScala.map { expressionContext =>
      val parsed = expression(expressionContext)
      replaceIntegerLiteralWithOrdinal(
        expression = parsed,
        canReplaceWithOrdinal = conf.groupByOrdinal
      )
    }.toSeq
  }.toSeq
}

/**
 * Returns an [[UnresolvedOrdinal]] carrying the literal's value when `expression` is an
 * integer [[Literal]] of [[IntegerType]] and `canReplaceWithOrdinal` is true. In every other
 * case the input expression is returned unchanged.
 *
 * Only the given expression itself is inspected; its children are not traversed.
 */
private def replaceIntegerLiteralWithOrdinal(
    expression: Expression,
    canReplaceWithOrdinal: Boolean = true): Expression = {
  if (canReplaceWithOrdinal) {
    expression match {
      case Literal(ordinal: Int, IntegerType) => UnresolvedOrdinal(ordinal)
      case _ => expression
    }
  } else {
    expression
  }
}

/**
* Check plan for any parameters.
* If it finds any throws UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT.
Expand Down
Loading