From b23dd7b37fb57a4063e95f390b8cab69d8b37afd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 26 Oct 2017 22:43:00 -0700 Subject: [PATCH] refactoring --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++++++++++-------- .../scala/org/apache/spark/sql/Column.scala | 10 ++++----- .../spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../sql/execution/WholeStageCodegenExec.scala | 5 ++--- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d6a962a14dc9..6384a141e83b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -783,6 +783,17 @@ class Analyzer( } } + private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case u @ UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(resolve(_, q)) + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -841,15 +852,7 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q.transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - } + q.mapExpressions(resolve(_, q)) } def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8468a8a96349..92988680871a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -44,7 +44,7 @@ private[sql] object Column { e match { case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => a.aggregateFunction.toString - case expr => usePrettyExpression(expr).sql + case expr => toPrettySQL(expr) } } } @@ -137,7 +137,7 @@ class Column(val expr: Expression) extends Logging { case _ => UnresolvedAttribute.quotedString(name) }) - override def toString: String = usePrettyExpression(expr).sql + override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -175,7 +175,7 @@ class Column(val expr: Expression) extends Logging { case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c) } match { case ne: NamedExpression => ne - case other => Alias(expr, usePrettyExpression(expr).sql)() + case _ => Alias(expr, toPrettySQL(expr))() } case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => @@ -184,7 +184,7 @@ class Column(val expr: Expression) extends Logging { // Wait until the struct is resolved. This will generate a nicer looking alias. case struct: CreateNamedStructLike => UnresolvedAlias(struct) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6b45790d5ff6..21e94fa8bb0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf @@ -85,7 +85,7 @@ class RelationalGroupedDataset protected[sql]( case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + case expr: Expression => Alias(expr, toPrettySQL(expr))() } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index e37d133ff336..286cb3bb0767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -521,10 +521,9 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) - case j @ SortMergeJoinExec(_, _, _, _, left, right) => + case j: SortMergeJoinExec => // The children of SortMergeJoin should do codegen separately. - j.copy(left = InputAdapter(insertWholeStageCodegen(left)), - right = InputAdapter(insertWholeStageCodegen(right))) + j.withNewChildren(j.children.map(child => InputAdapter(insertWholeStageCodegen(child)))) case p => p.withNewChildren(p.children.map(insertInputAdapter)) }