diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b5016fdb29d9..b64a4db1ab64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -206,7 +206,7 @@ class Analyzer(
    * Analyze CTE definitions and substitute child plan with analyzed CTE definitions.
    */
   object CTESubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case With(child, relations) =>
         substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
           case (resolved, (name, relation)) =>
@@ -217,7 +217,7 @@ class Analyzer(

     def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = {
       plan resolveOperatorsDown {
-        case u : UnresolvedRelation =>
+        case u: UnresolvedRelation =>
           cteRelations.find(x => resolver(x._1, u.tableIdentifier.table))
             .map(_._2).getOrElse(u)
         case other =>
@@ -234,19 +234,16 @@ class Analyzer(
    * Substitute child plan with WindowSpecDefinitions.
    */
   object WindowsSubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       // Lookup WindowSpecDefinitions. This rule works with unresolved children.
-      case WithWindowDefinition(windowDefinitions, child) =>
-        child.resolveOperators {
-          case p => p.transformExpressions {
-            case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
-              val errorMessage =
-                s"Window specification $windowName is not defined in the WINDOW clause."
-              val windowSpecDefinition =
-                windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
-              WindowExpression(c, windowSpecDefinition)
-          }
-        }
+      case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions {
+        case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
+          val errorMessage =
+            s"Window specification $windowName is not defined in the WINDOW clause."
+          val windowSpecDefinition =
+            windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
+          WindowExpression(c, windowSpecDefinition)
+      }
     }
   }
@@ -274,7 +271,7 @@ class Analyzer(
     private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
       exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
@@ -494,7 +491,7 @@ class Analyzer(
     }

     // This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case a if !a.childrenResolved => a // be sure all of the children are resolved.

       // Ensure group by expressions and aggregate expressions have been resolved.
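Aside: the WindowsSubstitution rewrite above leans on resolveExpressions, which folds the old resolveOperators-plus-transformExpressions nesting into a single call. A minimal sketch of that shape, assuming spark-catalyst on the classpath (ExampleExpressionRule and its no-op rule are illustrative names, not part of this patch):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object ExampleExpressionRule extends Rule[LogicalPlan] {
  // A real rule would rewrite some expression here; identity keeps the sketch compilable.
  private val exprRule: PartialFunction[Expression, Expression] = {
    case e: Expression => e
  }

  // Equivalent to the pre-patch shape:
  //   plan.resolveOperators { case p => p transformExpressions exprRule }
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions(exprRule)
}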
@@ -527,7 +524,7 @@ class Analyzer(
   }

   object ResolvePivot extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
         || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
         || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
@@ -705,7 +702,7 @@ class Analyzer(
       case _ => plan
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
         EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
           case v: View =>
@@ -897,7 +894,7 @@ class Analyzer(
       case _ => e.mapChildren(resolve(_, q))
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p: LogicalPlan if !p.childrenResolved => p

       // If the projection list contains Stars, expand it.
@@ -1091,7 +1088,7 @@ class Analyzer(
    * have no effect on the results.
    */
   object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.childrenResolved => p
       // Replace the index with the related attribute for ORDER BY,
       // which is a 1-based position of the projection list.
@@ -1147,7 +1144,7 @@ class Analyzer(
       }}
     }

-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case agg @ Aggregate(groups, aggs, child)
           if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
             groups.exists(!_.resolved) =>
@@ -1171,7 +1168,7 @@ class Analyzer(
    * The HAVING clause could also use grouping columns that are not present in the SELECT.
    */
   object ResolveMissingReferences extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
       case sa @ Sort(_, _, child: Aggregate) => sa
@@ -1307,7 +1304,7 @@ class Analyzer(
    * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
    */
   object ResolveFunctions extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case q: LogicalPlan =>
         q transformExpressions {
           case u if !u.childrenResolved => u // Skip until children are resolved.
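As a usage note for ResolveOrdinalInOrderByAndGroupBy a few hunks above: ordinals are 1-based positions into the select list. A hedged sketch, assuming a SparkSession named spark and a table emp with a dept column (both illustrative):

// GROUP BY 1 resolves to GROUP BY dept; ORDER BY 2 resolves to ORDER BY count(*).
spark.sql("SELECT dept, count(*) FROM emp GROUP BY 1 ORDER BY 2")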
@@ -1444,7 +1441,7 @@ class Analyzer(
     /**
      * Resolve and rewrite all subqueries in an operator tree.
      */
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       // In case of HAVING (a filter after an aggregate) we use both the aggregate and
       // its child for resolution.
       case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -1460,7 +1457,7 @@ class Analyzer(
    */
   object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved =>
         // Resolves output attributes if a query has alias names in its subquery:
         // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2)
@@ -1509,7 +1506,7 @@ class Analyzer(
    * underlying aggregate operator and then projected away after the original operator.
    */
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>

         // Try resolving the condition of the filter as though it is in the aggregate clause
@@ -1686,7 +1683,7 @@ class Analyzer(
       }
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
         val nestedGenerator = projectList.find(hasNestedGenerator).get
         throw new AnalysisException("Generators are not supported when it's nested in " +
@@ -1746,7 +1743,7 @@ class Analyzer(
    * that wrap the [[Generator]].
    */
   object ResolveGenerate extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case g: Generate if !g.child.resolved || !g.generator.resolved => g
       case g: Generate if !g.resolved =>
         g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -1787,7 +1784,7 @@ class Analyzer(
    */
   object FixNullability extends Rule[LogicalPlan] {

-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case p if !p.resolved => p // Skip unresolved nodes.
       case p: LogicalPlan if p.resolved =>
         val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap {
@@ -2071,7 +2068,7 @@ class Analyzer(
    * put them into an inner Project and finally project them away at the outer Project.
    */
   object PullOutNondeterministic extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.resolved => p // Skip unresolved nodes.
       case p: Project => p
       case f: Filter => f
@@ -2115,7 +2112,7 @@ class Analyzer(
   object ResolveRandomSeed extends Rule[LogicalPlan] {
     private lazy val random = new Random()

-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if p.resolved => p
       case p => p transformExpressionsUp {
         case Uuid(None) => Uuid(Some(random.nextLong()))
@@ -2131,7 +2128,7 @@ class Analyzer(
    * and we should return null if the input is null.
    */
   object HandleNullInputsForUDF extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.resolved => p // Skip unresolved nodes.
       case p => p transformExpressionsUp {
@@ -2166,25 +2163,21 @@ class Analyzer(
    * Check and add proper window frames for all window functions.
    */
   object ResolveWindowFrame extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
-      case logical: LogicalPlan => logical transformExpressions {
-        case WindowExpression(wf: WindowFunction,
-            WindowSpecDefinition(_, _, f: SpecifiedWindowFrame))
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+      case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame))
           if wf.frame != UnspecifiedFrame && wf.frame != f =>
-          failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}")
-        case WindowExpression(wf: WindowFunction,
-            s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
+        failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}")
+      case WindowExpression(wf: WindowFunction, s @ WindowSpecDefinition(_, _, UnspecifiedFrame))
           if wf.frame != UnspecifiedFrame =>
-          WindowExpression(wf, s.copy(frameSpecification = wf.frame))
-        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
+        WindowExpression(wf, s.copy(frameSpecification = wf.frame))
+      case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
           if e.resolved =>
-          val frame = if (o.nonEmpty) {
-            SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
-          } else {
-            SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
-          }
-          we.copy(windowSpec = s.copy(frameSpecification = frame))
-      }
+        val frame = if (o.nonEmpty) {
+          SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
+        } else {
+          SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
+        }
+        we.copy(windowSpec = s.copy(frameSpecification = frame))
     }
   }
@@ -2192,16 +2185,14 @@ class Analyzer(
    * Check and add order to [[AggregateWindowFunction]]s.
    */
   object ResolveWindowOrder extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
-      case logical: LogicalPlan => logical transformExpressions {
-        case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty =>
-          failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " +
-            s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " +
-            s"ORDER BY window_ordering) from table")
-        case WindowExpression(rank: RankLike, spec) if spec.resolved =>
-          val order = spec.orderSpec.map(_.child)
-          WindowExpression(rank.withOrder(order), spec)
-      }
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+      case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty =>
+        failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " +
+          s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " +
+          s"ORDER BY window_ordering) from table")
+      case WindowExpression(rank: RankLike, spec) if spec.resolved =>
+        val order = spec.orderSpec.map(_.child)
+        WindowExpression(rank.withOrder(order), spec)
     }
   }
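Aside on the last case in ResolveWindowFrame: when a resolved window expression leaves the frame unspecified, the analyzer defaults to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW if there is an ORDER BY, and ROWS over the whole partition otherwise. A standalone sketch of that decision (the names are local to this sketch, not Spark's):

object DefaultFrameSketch extends App {
  // Mirrors the if (o.nonEmpty) branch above, expressed as the SQL frame it implies.
  def defaultFrame(hasOrderSpec: Boolean): String =
    if (hasOrderSpec) "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
    else "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"

  // e.g. sum(x) OVER (ORDER BY y) gets the running-total RANGE frame:
  println(defaultFrame(hasOrderSpec = true))
}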
@@ -2210,8 +2201,8 @@ class Analyzer(
    * Then apply a Project on a normal Join to eliminate natural or using join.
    */
   object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
-      case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+      case j @ Join(left, right, UsingJoin(joinType, usingCols), _)
           if left.resolved && right.resolved && j.duplicateResolved =>
         commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
       case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
@@ -2275,7 +2266,7 @@ class Analyzer(
    * to the given input attributes.
    */
   object ResolveDeserializer extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.childrenResolved => p
       case p if p.resolved => p
@@ -2361,7 +2352,7 @@ class Analyzer(
    * constructed is an inner class.
    */
   object ResolveNewInstance extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.childrenResolved => p
       case p if p.resolved => p
@@ -2395,7 +2386,7 @@ class Analyzer(
         "type of the field in the target object")
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.childrenResolved => p
       case p if p.resolved => p
@@ -2417,9 +2408,8 @@ class Analyzer(
  * scoping information for attributes and can be removed once analysis is complete.
  */
 object EliminateSubqueryAliases extends Rule[LogicalPlan] {
-  // This is actually called in the beginning of the optimization phase, and as a result
-  // is using transformUp rather than resolveOperators. This is also often called in the
-  //
+  // This is also called in the beginning of the optimization phase, and as a result
+  // is using transformUp rather than resolveOperators.
   def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer {
     plan transformUp {
       case SubqueryAlias(_, child) => child
@@ -2431,7 +2421,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] {
 * Removes [[Union]] operators from the plan if it just has one child.
 */
object EliminateUnions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case Union(children) if children.size == 1 => children.head
  }
}
@@ -2462,7 +2452,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
     case other => trimAliases(other)
   }

-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
     case Project(projectList, child) =>
       val cleanedProjectList =
         projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@@ -2472,7 +2462,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
       val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
       Aggregate(grouping.map(trimAliases), cleanedAggs, child)

-    case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
+    case Window(windowExprs, partitionSpec, orderSpec, child) =>
       val cleanedWindowExprs =
         windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
       Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
@@ -2496,7 +2486,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
 * TODO: add this rule into analyzer rule list.
 */
object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case EventTimeWatermark(_, _, child) if !child.isStreaming => child
  }
}
@@ -2541,7 +2531,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
  * @return the logical plan that will generate the time windows using the Expand operator, with
  *         the Filter operator for correctness and Project for usability.
  */
-  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case p: LogicalPlan if p.children.size == 1 =>
      val child = p.children.head
      val windowExpressions =
@@ -2692,7 +2682,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] {
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan resolveOperatorsDown {
+    plan resolveOperators {
      case f @ Filter(_, a: Aggregate) if f.resolved =>
        f transformExpressions {
          case s: SubqueryExpression if s.children.nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 80d5105c2de8..dbd4ed845e32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -86,7 +86,7 @@ object ResolveHints {
      }
    }

-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
      case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
        if (h.parameters.isEmpty) {
          // If there is no table alias specified, turn the entire subtree into a BroadcastHint.
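For context on the broadcast-hint rule above, a hedged usage sketch (assumes a SparkSession named spark with tables t and u registered; all names are illustrative):

import org.apache.spark.sql.functions.broadcast

// SQL hint form, matched against BROADCAST_HINT_NAMES above:
spark.sql("SELECT /*+ BROADCAST(t) */ * FROM t JOIN u ON t.id = u.id")
// Dataset API spelling of the same hint:
spark.table("t").join(broadcast(spark.table("u")), "id")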
@@ -134,7 +134,7 @@ object ResolveHints {
   * This must be executed after all the other hint rules are executed.
   */
  object RemoveAllHints extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
      case h: UnresolvedHint => h.child
    }
  }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 7dd26b62b1fc..57a91d153cc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -318,7 +318,7 @@ object TypeCoercion {
    */
   object WidenSetOperationTypes extends Rule[LogicalPlan] {

-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case s @ Except(left, right, isAll) if s.childrenResolved &&
           left.output.length == right.output.length && !s.resolved =>
         val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
@@ -757,17 +757,18 @@ object TypeCoercion {
    */
   case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {

-    override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p =>
-      p transformExpressionsUp {
-        // Skip nodes if unresolved or empty children
-        case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
-        case c @ Concat(children) if conf.concatBinaryAsString ||
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      plan resolveOperators { case p =>
+        p transformExpressionsUp {
+          // Skip nodes if unresolved or empty children
+          case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
+          case c @ Concat(children) if conf.concatBinaryAsString ||
             !children.map(_.dataType).forall(_ == BinaryType) =>
-          val newChildren = c.children.map { e =>
-            ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
-          }
-          c.copy(children = newChildren)
+            val newChildren = c.children.map { e =>
+              ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+            }
+            c.copy(children = newChildren)
+        }
       }
     }
@@ -780,23 +781,24 @@ object TypeCoercion {
    */
   case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {

-    override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p =>
-      p transformExpressionsUp {
-        // Skip nodes if unresolved or not enough children
-        case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
-        case c @ Elt(children) =>
-          val index = children.head
-          val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
-          val newInputs = if (conf.eltOutputAsString ||
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      plan resolveOperators { case p =>
+        p transformExpressionsUp {
+          // Skip nodes if unresolved or not enough children
+          case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
+          case c @ Elt(children) =>
+            val index = children.head
+            val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
+            val newInputs = if (conf.eltOutputAsString ||
               !children.tail.map(_.dataType).forall(_ == BinaryType)) {
-            children.tail.map { e =>
-              ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+              children.tail.map { e =>
+                ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+              }
+            } else {
+              children.tail
             }
-          } else {
-            children.tail
-          }
-          c.copy(children = newIndex +: newInputs)
+            c.copy(children = newIndex +: newInputs)
+        }
       }
     }
@@ -1007,7 +1009,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {

   protected def coerceTypes(plan: LogicalPlan): LogicalPlan

-  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
     // No propagation required for leaf nodes.
     case q: LogicalPlan if q.children.isEmpty => q
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index 063ca0fc3252..5e2029c251ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -32,20 +32,17 @@ import org.apache.spark.sql.types.DataType
  */
 case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {

-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
-    case q: LogicalPlan =>
-      q.transformExpressions {
-        case u @ UnresolvedFunction(fn, children, false)
-            if hasLambdaAndResolvedArguments(children) =>
-          withPosition(u) {
-            catalog.lookupFunction(fn, children) match {
-              case func: HigherOrderFunction => func
-              case other => other.failAnalysis(
-                "A lambda function should only be used in a higher order function. However, " +
-                  s"its class is ${other.getClass.getCanonicalName}, which is not a " +
-                  s"higher order function.")
-            }
-          }
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
+    case u @ UnresolvedFunction(fn, children, false)
+        if hasLambdaAndResolvedArguments(children) =>
+      withPosition(u) {
+        catalog.lookupFunction(fn, children) match {
+          case func: HigherOrderFunction => func
+          case other => other.failAnalysis(
+            "A lambda function should only be used in a higher order function. However, " +
+              s"its class is ${other.getClass.getCanonicalName}, which is not a " +
+              s"higher order function.")
+        }
       }
   }
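A hedged sketch of what ResolveHigherOrderFunctions accepts and rejects (assumes a SparkSession named spark; transform is a built-in higher-order function here):

// Lambda passed to a higher-order function: looked up and resolved eagerly.
spark.sql("SELECT transform(array(1, 2, 3), x -> x + 1)").show()
// Lambda passed to an ordinary function: fails analysis with the message above.
spark.sql("SELECT upper(x -> x)")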
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
index feeb6553d106..af74693000c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
@@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
 * completely resolved during the batch of Resolution.
 */
case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
    case v @ View(desc, output, child) if child.resolved && output != child.output =>
      val resolver = conf.resolver
      val queryColumnNames = desc.viewQueryColumnNames
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 039acc1ea4fa..9404a809b453 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -60,6 +60,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
    */
   def analyzed: Boolean = _analyzed

+  /**
+   * Returns a copy of this node where `rule` has been recursively applied to the tree. When
+   * `rule` does not apply to a given node, it is left unchanged. This function is similar to
+   * `transform`, but skips sub-trees that have already been marked as analyzed.
+   * Users should not expect a specific directionality. If a specific directionality is needed,
+   * [[resolveOperatorsUp]] or [[resolveOperatorsDown]] should be used.
+   *
+   * @param rule the function used to transform this node's children
+   */
+  def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    resolveOperatorsDown(rule)
+  }
+
   /**
    * Returns a copy of this node where `rule` has been recursively applied first to all of its
    * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node,
@@ -68,10 +81,10 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
    *
    * @param rule the function used to transform this node's children
    */
-  def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+  def resolveOperatorsUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
       AnalysisHelper.allowInvokingTransformsInAnalyzer {
-        val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
+        val afterRuleOnChildren = mapChildren(_.resolveOperatorsUp(rule))
         if (self fastEquals afterRuleOnChildren) {
           CurrentOrigin.withOrigin(origin) {
             rule.applyOrElse(self, identity[LogicalPlan])
@@ -87,7 +100,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     }
   }

-  /** Similar to [[resolveOperators]], but does it top-down. */
+  /** Similar to [[resolveOperatorsUp]], but does it top-down. */
   def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
       AnalysisHelper.allowInvokingTransformsInAnalyzer {
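To make the new contract concrete: a minimal, self-contained sketch (toy types, not Spark's) of the traversal skeletons behind resolveOperatorsUp and resolveOperatorsDown, omitting the analyzed-subtree skipping:

sealed trait Plan {
  def children: Seq[Plan]
  def mapChildren(f: Plan => Plan): Plan

  // Post-order, like resolveOperatorsUp: children first, so the rule sees each
  // node with already-rewritten children.
  def resolveUp(rule: PartialFunction[Plan, Plan]): Plan =
    rule.applyOrElse(mapChildren(_.resolveUp(rule)), identity[Plan])

  // Pre-order, like resolveOperatorsDown: this node first, then recurse into
  // whatever children the rewrite produced.
  def resolveDown(rule: PartialFunction[Plan, Plan]): Plan =
    rule.applyOrElse(this, identity[Plan]).mapChildren(_.resolveDown(rule))
}

case class Node(name: String, children: Plan*) extends Plan {
  def mapChildren(f: Plan => Plan): Plan = Node(name, children.map(f): _*)
}

object TraversalDemo extends App {
  val rule: PartialFunction[Plan, Plan] = { case Node("rel") => Node("resolvedRel") }
  // Rules that must observe rewritten children (most analyzer rules in this
  // patch) need the Up variant; Down sees the original children.
  println(Node("project", Node("rel")).resolveUp(rule))
}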
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a4bf990ea9d6..f65948d39a1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1383,8 +1383,7 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
-      logicalPlan)
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)

     if (encoder.flat) {
       new Dataset[U1](sparkSession, project, encoder)
@@ -1658,15 +1657,14 @@ class Dataset[T] private[sql](
   @Experimental
   @InterfaceStability.Evolving
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
-    val inputPlan = logicalPlan
-    val withGroupingKey = AppendColumns(func, inputPlan)
+    val withGroupingKey = AppendColumns(func, logicalPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)

     new KeyValueGroupedDataset(
       encoderFor[K],
       encoderFor[T],
       executed,
-      inputPlan.output,
+      logicalPlan.output,
       withGroupingKey.newColumns)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e1b049b6ceab..6b61e749e306 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
     projectList
   }

-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
       DDLUtils.checkDataColNames(tableDesc)
       CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
@@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
       table.partitionSchema.asNullable.toAttributes)
   }

-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _)
         if DDLUtils.isDatasourceTable(tableMeta) =>
       i.copy(table = readDataSourceTable(tableMeta))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 3170180b32b8..949aa665527a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
   // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
   private def catalog = sparkSession.sessionState.catalog

-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // When we CREATE TABLE without specifying the table schema, we should fail the query if
     // bucketing information is specified, as we can't infer bucketing from data files currently.
     // Since the runtime inferred partition columns could be different from what user specified,
@@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
     }
   }

-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved =>
       table match {
         case relation: HiveTableRelation =>