From 05872b718a7269173d65bd1e6c7d0bd029d17bab Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sat, 20 Aug 2016 02:02:24 +0900
Subject: [PATCH 01/12] Fixed an issue where self-joins or similar join patterns can cause wrong results

---
 .../sql/catalyst/analysis/Analyzer.scala | 29 ++++++++++---
 .../sql/catalyst/analysis/unresolved.scala | 41 ++++++++++++++++++-
 .../expressions/namedExpressions.scala | 2 +-
 .../catalyst/plans/logical/LogicalPlan.scala | 13 ++++++
 .../spark/sql/catalyst/trees/TreeNode.scala | 20 ++++++++-
 .../spark/sql/DataFrameNaFunctions.scala | 14 +++----
 .../scala/org/apache/spark/sql/Dataset.scala | 31 ++++++++++----
 .../org/apache/spark/sql/DataFrameSuite.scala | 23 ++++++++++-
 8 files changed, 149 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 41e0e6d65e9a..24fd1c8f6cfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -530,11 +530,16 @@ class Analyzer(
         val newRight = right transformUp {
           case r if r == oldRelation => newRelation
         } transformUp {
-          case other => other transformExpressions {
-            case a: Attribute =>
-              attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
-          }
+          case other =>
+            val transformed = other transformExpressions {
+              case a: Attribute =>
+                attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
+            }
+
+            transformed.setPlanId(other.planId)
+            transformed
         }
+        newRight.setPlanId(right.planId)
         newRight
       }
   }
@@ -597,7 +602,7 @@ class Analyzer(
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
-        q transformExpressionsUp {
+        q transformExpressionsUp {
          case u @ UnresolvedAttribute(nameParts) =>
            // Leave unchanged if resolution fails. Hopefully will be resolved next round.
            val result =
@@ -606,6 +611,20 @@ class Analyzer(
            result
          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
            ExtractValue(child, fieldExpr, resolver)
+          case l: LazilyDeterminedAttribute =>
+            val foundPlanOpt = q.findByBreadthFirst(_.planId == l.plan.planId)
+            val foundPlan = foundPlanOpt.getOrElse {
+              failAnalysis(s"""Cannot resolve column name "${l.name}" """)
+            }
+
+            if (foundPlan == l.plan) {
+              l.namedExpr
+            } else {
+              foundPlan.resolveQuoted(l.name, resolver).getOrElse {
+                failAnalysis(s"""Cannot resolve column name "${l.name}" """ +
+                  s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""")
+              }
+            }
        }
      }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 235ae0478245..d8ef181bd717 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -161,7 +161,7 @@ object UnresolvedAttribute {
       }
       if (inBacktick) throw e
       nameParts += tmp.mkString
-      nameParts.toSeq
+      nameParts
     }
   }
@@ -419,3 +419,42 @@ case class UnresolvedOrdinal(ordinal: Int)
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
+
+/**
+ * Used when we refer to a column like `df("expr")`; it lazily determines
+ * which expression `df("expr")` should point to.
+ * Normally, `df("expr")` should point to the expression (say expr1) that the
+ * logical plan in `df` outputs, but in some cases `df("expr")` should
+ * point to another expression (say expr2) rather than expr1, where expr2
+ * is identical to expr1 except for its exprId.
+ * This happens when datasets are self-joined or in similar situations, because
+ * the analyzer re-creates the logical plans and their output expressions with new exprIds.
+ * [[LazilyDeterminedAttribute()]] handles this case properly by letting
+ * the analyzer determine which expression `df("expr")` should point to.
+ *
+ * @param namedExpr The expression which a column reference should normally point to.
+ * @param plan The logical plan which contains the expression
+ * which the column reference should lazily point to.
+ */
+case class LazilyDeterminedAttribute(
+    namedExpr: NamedExpression)(
+    val plan: LogicalPlan)
+  extends Attribute with Unevaluable {
+  // Keep the constructor curried so that column references can be compared
+  // like df1("col1") == df2("col1"), which tests in particular rely on.
+
+  override def name: String = namedExpr.name
+  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
+  override lazy val resolved = false
+
+  override def newInstance(): Attribute = throw new UnresolvedException(this, "newInstance")
+  override def withNullability(newNullability: Boolean): Attribute =
+    throw new UnresolvedException(this, "withNullability")
+  override def withName(newName: String): Attribute =
+    throw new UnresolvedException(this, "withName")
+  override def withQualifier(newQualifier: Option[String]): Attribute =
+    throw new UnresolvedException(this, "withQualifier")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 306a99d5a37b..5fbc15c72fc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -218,7 +218,7 @@ case class AttributeReference(
   extends Attribute with Unevaluable {
 
   /**
-   * Returns true iff the expression id is the same for both attributes.
+   * Returns true if the expression id is the same for both attributes.
   */
  def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 6d7799151d93..370aa22280da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -25,11 +25,20 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.types.StructType
 
+object LogicalPlan {
+  private val curId = new java.util.concurrent.atomic.AtomicLong()
+  def newPlanId: Long = curId.getAndIncrement()
+}
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
 
   private var _analyzed: Boolean = false
 
+  // A logical plan keeps the same planId even when the analyzer replaces it
+  // with a re-created copy while deduplicating expressions that share the
+  // same exprId, so the id can be used to find the plan again afterwards.
+  private var _planId: Long = LogicalPlan.newPlanId
+
   /**
    * Marks this plan as already analyzed. This should only be called by CheckAnalysis.
    */
@@ -42,6 +51,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    */
   def analyzed: Boolean = _analyzed
 
+  private[catalyst] def setPlanId(planId: Long): Unit = { _planId = planId }
+
+  def planId: Long = _planId
+
   /** Returns true if this subtree contains any streaming data sources. */
   def isStreaming: Boolean = children.exists(_.isStreaming == true)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 24a2dc9d3b35..180f4f25e15d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.trees
 import java.util.UUID
 
 import scala.collection.Map
-import scala.collection.mutable.Stack
+import scala.collection.mutable.{ArrayBuffer, Stack}
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -108,6 +108,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     case false => children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
   }
 
+  def findByBreadthFirst(f: BaseType => Boolean): Option[BaseType] = {
+    val queue = new ArrayBuffer[BaseType]
+    var foundOpt: Option[BaseType] = None
+    queue.append(this)
+
+    // Breadth-first search: return the shallowest node that satisfies f.
+    while (queue.nonEmpty && foundOpt.isEmpty) {
+      val currentNode = queue.remove(0)
+      f(currentNode) match {
+        case true => foundOpt = Option(currentNode)
+        case false =>
+          val childPlans = currentNode.children.reverse
+          childPlans.foreach(queue.append(_))
+      }
+    }
+    foundOpt
+  }
+
   /**
    * Runs the given function on this node and then recursively on [[children]].
    * @param f the function to be applied to each node in the tree.
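[Note on the hunk above] The findByBreadthFirst helper added to TreeNode is the backbone of the fix: given a predicate, it returns the shallowest node that satisfies it, so when several sub-plans carry the same planId, the one closest to the root wins. Below is a minimal, self-contained sketch of the same traversal; the SimpleNode type is hypothetical, and unlike the hunk above it visits children left-to-right rather than reversed.

import scala.collection.mutable

// Hypothetical stand-in for TreeNode, just to illustrate the traversal order.
case class SimpleNode(id: Long, children: Seq[SimpleNode] = Nil) {
  // Returns the first node satisfying f in breadth-first order, i.e. the
  // shallowest match; all siblings are examined before any grandchildren.
  def findByBreadthFirst(f: SimpleNode => Boolean): Option[SimpleNode] = {
    val queue = mutable.Queue(this)
    while (queue.nonEmpty) {
      val current = queue.dequeue()
      if (f(current)) return Some(current)
      queue ++= current.children
    }
    None
  }
}

object FindByBreadthFirstExample extends App {
  //      0
  //     / \
  //    1   2
  //    |
  //    3
  val tree = SimpleNode(0, Seq(SimpleNode(1, Seq(SimpleNode(3))), SimpleNode(2)))
  // Visit order is 0, 1, 2, 3: a match at depth 1 beats a match at depth 2.
  assert(tree.findByBreadthFirst(_.id >= 2).map(_.id) == Some(2L))
}

In the patches that follow, this search runs against the analyzed plan of the current Dataset, so the shallowest sub-plan carrying the recorded planId is the one a column reference binds to.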
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ad00966a917a..036ee595d9fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -161,7 +161,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { fillCol[Double](f, value) } else { - df.col(f.name) + df.colInternal(f.name) } } df.select(projections : _*) @@ -188,7 +188,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { fillCol[String](f, value) } else { - df.col(f.name) + df.colInternal(f.name) } } df.select(projections : _*) @@ -363,7 +363,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } else if (f.dataType == targetColumnType && shouldReplace) { replaceCol(f, replacementMap) } else { - df.col(f.name) + df.colInternal(f.name) } } df.select(projections : _*) @@ -395,7 +395,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } - }.getOrElse(df.col(f.name)) + }.getOrElse(df.colInternal(f.name)) } df.select(projections : _*) } @@ -407,8 +407,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types - case _ => df.col(quotedColName) + nanvl(df.colInternal(quotedColName), lit(null)) // nanvl only supports these types + case _ => df.colInternal(quotedColName) } coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) } @@ -420,7 +420,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * TODO: This can be optimized to use broadcast join when replacementMap is large. 
*/ private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { - val keyExpr = df.col(col.name).expr + val keyExpr = df.colInternal(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => Seq(buildExpr(source), buildExpr(target)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6da99ce0dd68..663443cd9f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -899,7 +899,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): Dataset[T] = { - sort((sortCol +: sortCols).map(apply) : _*) + sort((sortCol +: sortCols).map(colInternal) : _*) } /** @@ -952,12 +952,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def col(colName: String): Column = colName match { - case "*" => - Column(ResolvedStar(queryExecution.analyzed.output)) - case _ => - val expr = resolve(colName) - Column(expr) + def col(colName: String): Column = withStarResolved(colName) { + val candidateExpr = resolve(colName) + val expr = LazilyDeterminedAttribute(candidateExpr)(logicalPlan) + Column(expr) } /** @@ -1705,7 +1703,8 @@ class Dataset[T] private[sql]( val convert = CatalystTypeConverters.createToCatalystConverter(dataType) f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) } - val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil) + val generator = + UserDefinedGenerator(elementSchema, rowFunction, colInternal(inputColumn).expr :: Nil) withPlan { Generate(generator, join = true, outer = false, @@ -1836,6 +1835,12 @@ class Dataset[T] private[sql]( case Column(u: UnresolvedAttribute) => queryExecution.analyzed.resolveQuoted( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) + case Column(l: LazilyDeterminedAttribute) => + val foundExpression = + logicalPlan.findByBreadthFirst(_.planId == l.plan.planId) + .flatMap(_.resolveQuoted(l.name, sparkSession.sessionState.analyzer.resolver)) + .getOrElse(l.namedExpr) + foundExpression case Column(expr: Expression) => expr } val attrs = this.logicalPlan.output @@ -2628,6 +2633,16 @@ class Dataset[T] private[sql]( } } + private[sql] def colInternal(colName: String): Column = withStarResolved(colName) { + val expr = resolve(colName) + Column(expr) + } + + private def withStarResolved(colName: String)(f: => Column): Column = colName match { + case "*" => Column(ResolvedStar(queryExecution.analyzed.output)) + case _ => f + } + /** A convenient function to wrap a logical plan and produce a DataFrame. 
 */
  @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
    Dataset.ofRows(sparkSession, logicalPlan)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f3180379c..9a24ac203ebd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1469,7 +1469,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4)
     }
   }
-
   test("sameResult() on aggregate") {
     val df = spark.range(100)
     val agg1 = df.groupBy().count()
@@ -1578,4 +1577,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(rdd, StructType(schemas), false)
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }
+
+  test("""SPARK-17154: df("column_name") should return correct result when we do self-join""") {
+    val df = Seq(
+      (1, "a", "A"),
+      (2, "b", "B"),
+      (3, "c", "C"),
+      (4, "d", "D"),
+      (5, "e", "E")).toDF("col1", "col2", "col3")
+    val filtered = df.filter("col1 != 3").select("col1", "col2")
+    val joined = filtered.join(df, filtered("col1") === df("col1"), "inner")
+    val selected1 = joined.select(df("col3"))
+
+    checkAnswer(selected1, Row("A") :: Row("B") :: Row("D") :: Row("E") :: Nil)
+
+    val rightOuterJoined = filtered.join(df, filtered("col1") === df("col1"), "right")
+    val selected2 = rightOuterJoined.select(df("col1"))
+
+    checkAnswer(selected2, Row(1) :: Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
+
+    val selected3 = rightOuterJoined.select(filtered("col1"))
+    checkAnswer(selected3, Row(1) :: Row(2) :: Row(null) :: Row(4) :: Row(5) :: Nil)
+  }
 }

From 91cb915b4e6c3c4d24fab3f1e772e7e361d4c088 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sat, 20 Aug 2016 10:37:41 +0900
Subject: [PATCH 02/12] Fix ScriptTransformationSuite

---
 .../sql/hive/execution/ScriptTransformationSuite.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index a8e81d7a3c42..1a9fa54e700a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -53,7 +53,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -67,7 +67,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -82,7 +82,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -99,7 +99,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -116,7 +116,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
 
     val e = intercept[SparkException] {
       val plan =
         new ScriptTransformation(
-          input = Seq(rowsDf.col("a").expr),
+          input = Seq(rowsDf.colInternal("a").expr),
          script = "some_non_existent_command",
           output = Seq(AttributeReference("a", StringType)()),
           child = rowsDf.queryExecution.sparkPlan,

From dd0ddbcba4bb9b11df4f2353e956e5e1114ceade Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sat, 20 Aug 2016 15:01:24 +0900
Subject: [PATCH 03/12] Fixed more test cases

---
 .../sql/execution/joins/ExistenceJoinSuite.scala | 13 +++++++------
 .../spark/sql/execution/joins/InnerJoinSuite.scala | 10 +++++-----
 .../spark/sql/execution/joins/OuterJoinSuite.scala | 4 ++--
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 38377164c10e..3574db4bce82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -62,16 +62,16 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       Row(6, null)
     )), new StructType().add("c", IntegerType).add("d", DoubleType))
 
-  private lazy val singleConditionEQ = (left.col("a") === right.col("c")).expr
+  private lazy val singleConditionEQ = (left.colInternal("a") === right.colInternal("c")).expr
 
   private lazy val composedConditionEQ = {
-    And((left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And((left.colInternal("a") === right.colInternal("c")).expr,
+      LessThan(left.colInternal("b").expr, right.colInternal("d").expr))
   }
 
   private lazy val composedConditionNEQ = {
-    And((left.col("a") < right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And((left.colInternal("a") < right.colInternal("c")).expr,
+      LessThan(left.colInternal("b").expr, right.colInternal("d").expr))
   }
 
   // Note: the input dataframes and expression must be evaluated lazily because
@@ -252,6 +252,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
     LeftAnti,
     left,
     rightUniqueKey,
-    (left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")).expr,
+    (left.colInternal("a") === rightUniqueKey.colInternal("c") &&
+      left.colInternal("b") < rightUniqueKey.colInternal("d")).expr,
     Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null)))
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 35dab63672c0..1eefc79b8ffb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -219,7 +219,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
      "inner join, one match per row",
      myUpperCaseData,
      myLowerCaseData,
-      () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
+      () => (myUpperCaseData.colInternal("N") === myLowerCaseData.colInternal("n")).expr,
      Seq(
        (1, "A", 1, "a"),
        (2, "B", 2, "b"),
@@ -235,7 +235,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
      "inner join, multiple matches",
      left,
      right,
-      () => (left.col("a") === right.col("a")).expr,
+      () => (left.colInternal("a") === right.colInternal("a")).expr,
      Seq(
        (1, 1, 1, 1),
        (1, 1, 1, 2),
@@ -252,7 +252,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
      "inner join, no matches",
      left,
      right,
-      () => (left.col("a") === right.col("a")).expr,
+      () => (left.colInternal("a") === right.colInternal("a")).expr,
      Seq.empty
    )
  }
@@ -264,7 +264,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
      "inner join, null safe",
      left,
      right,
-      () => (left.col("b") <=> right.col("b")).expr,
+      () => (left.colInternal("b") <=> right.colInternal("b")).expr,
      Seq(
        (1, 0, 1, 0),
        (2, null, 2, null)
@@ -280,7 +280,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
      "SPARK-15822 - test structs as keys",
      left,
      right,
-      () => (left.col("key") === right.col("key")).expr,
+      () => (left.colInternal("key") === right.colInternal("key")).expr,
      Seq(
        (Row(0, 0), "L0", Row(0, 0), "R0"),
        (Row(1, 1), "L1", Row(1, 1), "R1"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 001feb0f2b39..297621663e1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -57,8 +57,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
    )), new StructType().add("c", IntegerType).add("d", DoubleType))
 
  private lazy val condition = {
-    And((left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And((left.colInternal("a") === right.colInternal("c")).expr,
+      LessThan(left.colInternal("b").expr, right.colInternal("d").expr))
  }
 
  // Note: the input dataframes and expression must be evaluated lazily because

From 74eb4aa43b390be820550105e2d2f7eaf47a8d09 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sat, 20 Aug 2016 23:35:20 +0900
Subject: [PATCH 04/12] Fixed analysis error in sorting

---
 .../spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 24fd1c8f6cfa..0be428e6654f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -712,6 +712,20 @@ class Analyzer(
        withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
      case UnresolvedExtractValue(child, fieldName) if child.resolved =>
        ExtractValue(child, fieldName, resolver)
+      case l: LazilyDeterminedAttribute =>
+        val foundPlanOpt = plan.findByBreadthFirst(_.planId == l.plan.planId)
+        val foundPlan = foundPlanOpt.getOrElse {
+          failAnalysis(s"""Cannot resolve column name "${l.name}" """)
+        }
+
+        if (foundPlan == l.plan) {
+          l.namedExpr
+        } else {
+          foundPlan.resolveQuoted(l.name, resolver).getOrElse {
+            failAnalysis(s"""Cannot resolve column name "${l.name}" """ +
+              s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""")
+          }
+        }
    }
  } catch {
    case a: AnalysisException if !throws => expr

From 48a0775e80cc91340cb0754c62b35868f319cf44 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sat, 20 Aug 2016 23:52:35 +0900
Subject: [PATCH 05/12] Refactored

---
 .../sql/catalyst/analysis/Analyzer.scala | 49 ++++++++-----------
 1 file changed, 21 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0be428e6654f..1de8c94e8e3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -611,20 +611,7 @@ class Analyzer(
            result
          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
            ExtractValue(child, fieldExpr, resolver)
-          case l: LazilyDeterminedAttribute =>
-            val foundPlanOpt = q.findByBreadthFirst(_.planId == l.plan.planId)
-            val foundPlan = foundPlanOpt.getOrElse {
-              failAnalysis(s"""Cannot resolve column name "${l.name}" """)
-            }
-
-            if (foundPlan == l.plan) {
-              l.namedExpr
-            } else {
-              foundPlan.resolveQuoted(l.name, resolver).getOrElse {
-                failAnalysis(s"""Cannot resolve column name "${l.name}" """ +
-                  s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""")
-              }
-            }
+          case l: LazilyDeterminedAttribute => resolveLazilyDeterminedAttribute(l, q)
        }
      }
@@ -697,6 +684,25 @@ class Analyzer(
      exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
    }
 
+    private def resolveLazilyDeterminedAttribute(
+        expr: LazilyDeterminedAttribute,
+        plan: LogicalPlan): Expression = {
+
+        val foundPlanOpt = plan.findByBreadthFirst(_.planId == expr.plan.planId)
+        val foundPlan = foundPlanOpt.getOrElse {
+          failAnalysis(s"""Cannot resolve column name "${expr.name}" """)
+        }
+
+        if (foundPlan == expr.plan) {
+          expr.namedExpr
+        } else {
+          foundPlan.resolveQuoted(expr.name, resolver).getOrElse {
+            failAnalysis(s"""Cannot resolve column name "${expr.name}" """ +
+              s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""")
+          }
+        }
+    }
+
    protected[sql] def resolveExpression(
        expr: Expression,
        plan: LogicalPlan,
@@ -712,20 +718,7 @@ class Analyzer(
        withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
      case UnresolvedExtractValue(child, fieldName) if child.resolved =>
        ExtractValue(child, fieldName, resolver)
-      case l: LazilyDeterminedAttribute =>
-        val foundPlanOpt = plan.findByBreadthFirst(_.planId == l.plan.planId)
-        val foundPlan = foundPlanOpt.getOrElse {
-          failAnalysis(s"""Cannot resolve column name "${l.name}" """)
-        }
-
-        if (foundPlan == l.plan) {
-          l.namedExpr
-        } else {
-          foundPlan.resolveQuoted(l.name, resolver).getOrElse {
-            failAnalysis(s"""Cannot resolve column name "${l.name}" """ +
-              s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""")
-          }
-        }
+      case l: LazilyDeterminedAttribute => resolveLazilyDeterminedAttribute(l, plan)
    }
  } catch {
    case a: AnalysisException if !throws => expr

From 9ddc9d858fc3d5b269a8a762b356a545f70646d6 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sun, 21 Aug 2016 03:53:33 +0900
Subject: [PATCH 06/12] Fixed flatMapGroupsInR to ensure expressions for Deserializer are resolved

---
 .../org/apache/spark/sql/RelationalGroupedDataset.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7cfd1cdc7d5d..4b1be6296a15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -21,9 +21,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} @@ -409,7 +407,9 @@ class RelationalGroupedDataset protected[sql]( packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]], outputSchema: StructType): DataFrame = { - val groupingNamedExpressions = groupingExprs.map(alias) + val groupingNamedExpressions = groupingExprs + .map(df.sparkSession.sessionState.analyzer.resolveExpression(_, df.logicalPlan)) + .map(alias) val groupingCols = groupingNamedExpressions.map(Column(_)) val groupingDataFrame = df.select(groupingCols : _*) val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) From 148b6d56ac3a08e3113e1773fe275bb8ea9efac5 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 21 Aug 2016 17:37:58 +0900 Subject: [PATCH 07/12] Fixed style --- .../sql/catalyst/analysis/Analyzer.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1de8c94e8e3d..0694aa404b23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -688,19 +688,19 @@ class Analyzer( expr: LazilyDeterminedAttribute, plan: LogicalPlan): Expression = { - val foundPlanOpt = plan.findByBreadthFirst(_.planId == expr.plan.planId) - val foundPlan = foundPlanOpt.getOrElse { - failAnalysis(s"""Cannot resolve column name "${expr.name}" """) - } + val foundPlanOpt = plan.findByBreadthFirst(_.planId == expr.plan.planId) + val foundPlan = foundPlanOpt.getOrElse { + failAnalysis(s"""Cannot resolve column name "${expr.name}" """) + } - if (foundPlan == expr.plan) { - expr.namedExpr - } else { - foundPlan.resolveQuoted(expr.name, resolver).getOrElse { - failAnalysis(s"""Cannot resolve column name "${expr.name}" """ + - s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""") - } + if (foundPlan == expr.plan) { + expr.namedExpr + } else { + foundPlan.resolveQuoted(expr.name, resolver).getOrElse { + failAnalysis(s"""Cannot resolve column name "${expr.name}" """ + + s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""") } + } } protected[sql] def resolveExpression( From 021977f8e1b736e9aafc3955fd8926462ec1a2da Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 28 Aug 2016 05:05:26 +0900 Subject: [PATCH 08/12] Implemented another idea --- python/pyspark/sql/tests.py | 1 - .../sql/catalyst/analysis/Analyzer.scala | 64 +++++++++++-------- .../sql/catalyst/analysis/unresolved.scala | 43 +------------ .../catalyst/encoders/ExpressionEncoder.scala | 2 +- 
.../sql/catalyst/parser/AstBuilder.scala | 4 +- .../spark/sql/DataFrameNaFunctions.scala | 16 ++--- .../scala/org/apache/spark/sql/Dataset.scala | 33 ++++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 6 +- .../spark/sql/execution/SparkPlanTest.scala | 2 +- 9 files changed, 77 insertions(+), 94 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fc41701b5922..d3e9a62139be 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1137,7 +1137,6 @@ def test_access_column(self): self.assertTrue(isinstance(df['key'], Column)) self.assertTrue(isinstance(df[0], Column)) self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(AnalysisException, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0694aa404b23..bd00fc8afeda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -603,15 +603,21 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) => + case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + withPosition(u) { + targetPlanIdOpt match { + case Some(targetPlanId) => + resolveExpressionFromSpecificLogicalPlan(nameParts, q, targetPlanId) + case None => + q.resolveChildren(nameParts, resolver).getOrElse(u) + } + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => ExtractValue(child, fieldExpr, resolver) - case l: LazilyDeterminedAttribute => resolveLazilyDeterminedAttribute(l, q) } } @@ -684,22 +690,18 @@ class Analyzer( exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) } - private def resolveLazilyDeterminedAttribute( - expr: LazilyDeterminedAttribute, - plan: LogicalPlan): Expression = { - - val foundPlanOpt = plan.findByBreadthFirst(_.planId == expr.plan.planId) - val foundPlan = foundPlanOpt.getOrElse { - failAnalysis(s"""Cannot resolve column name "${expr.name}" """) - } - - if (foundPlan == expr.plan) { - expr.namedExpr - } else { - foundPlan.resolveQuoted(expr.name, resolver).getOrElse { - failAnalysis(s"""Cannot resolve column name "${expr.name}" """ + - s"""among (${foundPlan.schema.fieldNames.mkString(", ")})""") - } + private[sql] def resolveExpressionFromSpecificLogicalPlan( + nameParts: Seq[String], + planToSearchFrom: LogicalPlan, + targetPlanId: Long): Expression = { + lazy val name = UnresolvedAttribute(nameParts).name + planToSearchFrom.findByBreadthFirst(_.planId == targetPlanId) match { + case Some(foundPlan) => + foundPlan.resolve(nameParts, resolver).getOrElse { + failAnalysis(s"Could not find $name in ${planToSearchFrom.output.mkString(", ")}") + } + case None => + failAnalysis(s"Could not find $name in ${planToSearchFrom.output.mkString(", ")}") } } @@ -714,11 +716,16 @@ class Analyzer( try { expr transformUp { case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) - case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { plan.resolve(nameParts, 
resolver).getOrElse(u) }
+        case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
+          withPosition(u) {
+            targetPlanIdOpt match {
+              case Some(targetPlanId) =>
+                resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
+              case None => plan.resolve(nameParts, resolver).getOrElse(u)
+            }
+          }
       case UnresolvedExtractValue(child, fieldName) if child.resolved =>
         ExtractValue(child, fieldName, resolver)
     }
   } catch {
     case a: AnalysisException if !throws => expr
@@ -942,12 +949,17 @@ class Analyzer(
     plan transformDown {
       case q: LogicalPlan if q.childrenResolved && !q.resolved =>
         q transformExpressions {
-          case u @ UnresolvedAttribute(nameParts) =>
+          case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
             withPosition(u) {
               try {
-                outer.resolve(nameParts, resolver) match {
-                  case Some(outerAttr) => OuterReference(outerAttr)
-                  case None => u
+                targetPlanIdOpt match {
+                  case Some(targetPlanId) =>
+                    resolveExpressionFromSpecificLogicalPlan(nameParts, outer, targetPlanId)
+                  case None =>
+                    outer.resolve(nameParts, resolver) match {
+                      case Some(outerAttr) => OuterReference(outerAttr)
+                      case None => u
+                    }
                 }
               } catch {
                 case _: AnalysisException => u
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index d8ef181bd717..109fbf921203 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -83,7 +83,9 @@ case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq
 /**
  * Holds the name of an attribute that has yet to be resolved.
  */
-case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
+case class UnresolvedAttribute(
+    nameParts: Seq[String],
+    targetPlanIdOpt: Option[Long] = None) extends Attribute with Unevaluable {
 
   def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
 
@@ -419,42 +419,3 @@ case class UnresolvedOrdinal(ordinal: Int)
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
-
-/**
- * Used when we refer to a column like `df("expr")`; it lazily determines
- * which expression `df("expr")` should point to.
- * Normally, `df("expr")` should point to the expression (say expr1) that the
- * logical plan in `df` outputs, but in some cases `df("expr")` should
- * point to another expression (say expr2) rather than expr1, where expr2
- * is identical to expr1 except for its exprId.
- * This happens when datasets are self-joined or in similar situations, because
- * the analyzer re-creates the logical plans and their output expressions with new exprIds.
- * [[LazilyDeterminedAttribute()]] handles this case properly by letting
- * the analyzer determine which expression `df("expr")` should point to.
- *
- * @param namedExpr The expression which a column reference should normally point to.
- * @param plan The logical plan which contains the expression
- * which the column reference should lazily point to.
- */
-case class LazilyDeterminedAttribute(
-    namedExpr: NamedExpression)(
-    val plan: LogicalPlan)
-  extends Attribute with Unevaluable {
-  // Keep the constructor curried so that column references can be compared
-  // like df1("col1") == df2("col1"), which tests in particular rely on.
-
-  override def name: String = namedExpr.name
-  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
-  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
-  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
-  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
-  override lazy val resolved = false
-
-  override def newInstance(): Attribute = throw new UnresolvedException(this, "newInstance")
-  override def withNullability(newNullability: Boolean): Attribute =
-    throw new UnresolvedException(this, "withNullability")
-  override def withName(newName: String): Attribute =
-    throw new UnresolvedException(this, "withName")
-  override def withQualifier(newQualifier: Option[String]): Attribute =
-    throw new UnresolvedException(this, "withQualifier")
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index b96b744b4fa9..6320cc107154 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -149,7 +149,7 @@ object ExpressionEncoder {
       } else {
         val input = GetColumnByOrdinal(index, enc.schema)
         val deserialized = enc.deserializer.transformUp {
-          case UnresolvedAttribute(nameParts) =>
+          case UnresolvedAttribute(nameParts, _) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
           case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 283e4d43ba2b..0b65e3f850c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1168,8 +1168,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
     val attr = ctx.fieldName.getText
     expression(ctx.base) match {
-      case UnresolvedAttribute(nameParts) =>
-        UnresolvedAttribute(nameParts :+ attr)
+      case UnresolvedAttribute(nameParts, targetPlanId) =>
+        UnresolvedAttribute(nameParts :+ attr, targetPlanId)
       case e =>
         UnresolvedExtractValue(e, Literal(attr))
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 036ee595d9fc..c6b8455a36c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -161,7 +161,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
         fillCol[Double](f, value)
       } else {
-        df.colInternal(f.name)
+        df.col(f.name)
       }
     }
     df.select(projections : _*)
@@ -188,7 +188,7 @@ final class DataFrameNaFunctions
private[sql](df: DataFrame) { if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { fillCol[String](f, value) } else { - df.colInternal(f.name) + df.col(f.name) } } df.select(projections : _*) @@ -363,7 +363,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } else if (f.dataType == targetColumnType && shouldReplace) { replaceCol(f, replacementMap) } else { - df.colInternal(f.name) + df.col(f.name) } } df.select(projections : _*) @@ -395,7 +395,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } - }.getOrElse(df.colInternal(f.name)) + }.getOrElse(df.col(f.name)) } df.select(projections : _*) } @@ -407,8 +407,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - nanvl(df.colInternal(quotedColName), lit(null)) // nanvl only supports these types - case _ => df.colInternal(quotedColName) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + case _ => df.col(quotedColName) } coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) } @@ -420,8 +420,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * TODO: This can be optimized to use broadcast join when replacementMap is large. */ private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { - val keyExpr = df.colInternal(col.name).expr - def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) + val keyExpr = df.col(col.name).expr + def buildExpr(v: Any) = Cast(Literal(v), col.dataType) val branches = replacementMap.flatMap { case (source, target) => Seq(buildExpr(source), buildExpr(target)) }.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 663443cd9f1a..1c1af387e32a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -899,7 +899,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): Dataset[T] = { - sort((sortCol +: sortCols).map(colInternal) : _*) + sort((sortCol +: sortCols).map(apply) : _*) } /** @@ -953,8 +953,9 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def col(colName: String): Column = withStarResolved(colName) { - val candidateExpr = resolve(colName) - val expr = LazilyDeterminedAttribute(candidateExpr)(logicalPlan) + val expr = UnresolvedAttribute( + UnresolvedAttribute.parseAttributeName(colName), + Some(queryExecution.analyzed.planId)) Column(expr) } @@ -1703,8 +1704,7 @@ class Dataset[T] private[sql]( val convert = CatalystTypeConverters.createToCatalystConverter(dataType) f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) } - val generator = - UserDefinedGenerator(elementSchema, rowFunction, colInternal(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil) withPlan { Generate(generator, join = true, outer = false, @@ -1832,15 +1832,17 @@ class Dataset[T] private[sql]( */ def drop(col: Column): DataFrame = { val expression = col match { - case Column(u: UnresolvedAttribute) => - queryExecution.analyzed.resolveQuoted( - u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) - case Column(l: LazilyDeterminedAttribute) => - val 
foundExpression =
-          logicalPlan.findByBreadthFirst(_.planId == l.plan.planId)
-            .flatMap(_.resolveQuoted(l.name, sparkSession.sessionState.analyzer.resolver))
-            .getOrElse(l.namedExpr)
-        foundExpression
+      case Column(u @ UnresolvedAttribute(nameParts, targetPlanIdOpt)) =>
+        val plan = queryExecution.analyzed
+        val analyzer = sparkSession.sessionState.analyzer
+        val resolver = analyzer.resolver
+
+        targetPlanIdOpt match {
+          case Some(targetPlanId) =>
+            analyzer.resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
+          case None =>
+            plan.resolveQuoted(u.name, resolver).getOrElse(u)
+        }
       case Column(expr: Expression) => expr
     }
     val attrs = this.logicalPlan.output
@@ -2633,6 +2635,9 @@ class Dataset[T] private[sql](
     }
   }
 
+  /** Another version of `col` which resolves an expression immediately.
+   * Mainly intended for tests, for example when passing columns to a SparkPlan.
+   */
   private[sql] def colInternal(colName: String): Column = withStarResolved(colName) {
     val expr = resolve(colName)
     Column(expr)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9a24ac203ebd..93e096da803f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -607,7 +607,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
         Row(id, name, age, salary)
       }.toSeq)
     assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
-    assert(df("id") == person("id"))
+    val dfAnalyzer = df.sparkSession.sessionState.analyzer
+    val personAnalyzer = person.sparkSession.sessionState.analyzer
+    assert(dfAnalyzer.resolveExpression(df("id").expr, df.queryExecution.analyzed) ==
+      personAnalyzer.resolveExpression(person("id").expr, person.queryExecution.analyzed))
   }
 
   test("drop top level columns that contains dot") {
@@ -1469,6 +1472,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4)
     }
   }
+
   test("sameResult() on aggregate") {
     val df = spark.range(100)
     val agg1 = df.groupBy().count()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index b29e822add8b..d3f7216875cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -242,7 +242,7 @@ object SparkPlanTest {
     case plan: SparkPlan =>
       val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
       plan transformExpressions {
-        case UnresolvedAttribute(Seq(u)) =>
+        case UnresolvedAttribute(Seq(u), _) =>
           inputMap.getOrElse(u,
             sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
       }

From ccf71fcc8c366eddb4f43b6a79105e47576d99d5 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Sun, 28 Aug 2016 13:00:40 +0900
Subject: [PATCH 09/12] Modified error message

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f885b1f32457..e38660bed204 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -692,10 +692,10 @@ class Analyzer(
     planToSearchFrom.findByBreadthFirst(_.planId == targetPlanId) match {
       case Some(foundPlan) =>
         foundPlan.resolve(nameParts, resolver).getOrElse {
-          failAnalysis(s"Could not find $name in ${planToSearchFrom.output.mkString(", ")}")
+          failAnalysis(s"Could not find $name in ${foundPlan.output.mkString(", ")}")
         }
       case None =>
-        failAnalysis(s"Could not find $name in ${planToSearchFrom.output.mkString(", ")}")
+        failAnalysis(s"Could not find $name in any logical plan.")
     }
   }

From 437ac9934db4b9e2ac686b8c46e31621cf819033 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Wed, 28 Sep 2016 21:28:45 +0900
Subject: [PATCH 10/12] Reverted the previous change which prohibited self-join

---
 .../spark/sql/catalyst/analysis/Analyzer.scala | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9e11150eb884..24417f9082d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -490,31 +490,31 @@ class Analyzer(
     right.collect {
       // Handle base relations that might appear more than once.
       case oldVersion: MultiInstanceRelation
-        if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+          if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.newInstance()
         (oldVersion, newVersion)
 
       case oldVersion: SerializeFromObject
-        if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+          if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         (oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance())))
 
       // Handle projects that create conflicting aliases.
       case oldVersion @ Project(projectList, _)
-        if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+          if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
         (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
 
       case oldVersion @ Aggregate(_, aggregateExpressions, _)
-        if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+          if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
         (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
       case oldVersion: Generate
-        if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+          if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
         val newOutput = oldVersion.generatorOutput.map(_.newInstance())
         (oldVersion, oldVersion.copy(generatorOutput = newOutput))
 
       case oldVersion @ Window(windowExpressions, _, _, child)
-        if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
-          .nonEmpty =>
+          if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
+            .nonEmpty =>
         (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
     }
     // Only handle first case, others will be fixed on the next pass.
@@ -698,9 +698,7 @@ class Analyzer( lazy val name = UnresolvedAttribute(nameParts).name planToSearchFrom.findByBreadthFirst(_.planId == targetPlanId) match { case Some(foundPlan) => - foundPlan.resolve(nameParts, resolver).getOrElse { - failAnalysis(s"Could not find $name in ${foundPlan.output.mkString(", ")}") - } + foundPlan.resolve(nameParts, resolver).get case None => failAnalysis(s"Could not find $name in any logical plan.") } From 15bf529b2aab40b068485b748220744b64ac68f8 Mon Sep 17 00:00:00 2001 From: sarutak Date: Fri, 28 Oct 2016 13:31:47 +0900 Subject: [PATCH 11/12] Fixed style --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c12c11c424df..84c5237a0ef0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1601,7 +1601,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(rdd, StructType(schemas), false) assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } - + test("""SPARK-17154: df("column_name") should return correct result when we do self-join""") { val df = Seq( (1, "a", "A"), From 5d1ff3e601f9583d289a88f708230639a25a18b2 Mon Sep 17 00:00:00 2001 From: sarutak Date: Thu, 17 Nov 2016 15:45:38 +0900 Subject: [PATCH 12/12] Fix compile error --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdbfb657bcb2..59189b47f06f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -196,9 +196,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { val parts = ctx.expression.asScala.map { pVal => expression(pVal) match { - case UnresolvedAttribute(name :: Nil) => + case UnresolvedAttribute(name :: Nil, _) => name -> None - case cmp @ EqualTo(UnresolvedAttribute(name :: Nil), constant: Literal) => + case cmp @ EqualTo(UnresolvedAttribute(name :: Nil, _), constant: Literal) => name -> Option(constant.toString) case _ => throw new ParseException("Invalid partition filter specification", ctx) @@ -219,7 +219,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { expression(pVal) match { case EqualNullSafe(_, _) => throw new ParseException("'<=>' operator is not allowed in partition specification.", ctx) - case cmp @ BinaryComparison(UnresolvedAttribute(name :: Nil), constant: Literal) => + case cmp @ BinaryComparison(UnresolvedAttribute(name :: Nil, _), constant: Literal) => cmp.withNewChildren(Seq(AttributeReference(name, StringType)(), constant)) case _ => throw new ParseException("Invalid partition filter specification", ctx)
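[Closing note] Taken together, the series replaces LazilyDeterminedAttribute with a planId recorded directly on UnresolvedAttribute (the targetPlanIdOpt field introduced in patch 08): a column reference like df("col3") remembers the id of the plan it was created against, and resolution later searches the current, possibly re-created, plan tree for that id instead of trusting an exprId. A minimal, self-contained sketch of that scheme follows; Plan, ColumnRef, and resolve are illustrative stand-ins, not Spark's classes.

import java.util.concurrent.atomic.AtomicLong

// Illustrative stand-ins: Plan plays the role of LogicalPlan with its planId,
// and ColumnRef that of an UnresolvedAttribute carrying a targetPlanIdOpt.
object PlanIdResolutionSketch extends App {
  private val ids = new AtomicLong()

  case class Plan(output: Seq[String], children: Seq[Plan] = Nil) {
    val planId: Long = ids.getAndIncrement()
    // Depth-first here for brevity; the real lookup is breadth-first
    // (findByBreadthFirst), which matters only when ids could repeat.
    def find(targetId: Long): Option[Plan] =
      if (planId == targetId) Some(this)
      else children.view.flatMap(_.find(targetId)).headOption
  }

  case class ColumnRef(name: String, targetPlanId: Long)

  // Resolution first locates the plan the column was created against inside
  // the current (possibly rewritten) tree, then resolves the name there,
  // instead of trusting an exprId the analyzer may have re-created.
  def resolve(ref: ColumnRef, root: Plan): Option[String] =
    root.find(ref.targetPlanId).flatMap(_.output.find(_ == ref.name))

  // A self-join-like shape: the same base plan sits under both join children.
  val base = Plan(Seq("col1", "col2", "col3"))
  val filtered = Plan(Seq("col1", "col2"), Seq(base))
  val join = Plan(filtered.output ++ base.output, Seq(filtered, base))

  // A reference created from base still resolves under the join, even though
  // the filtered side's output does not contain "col3".
  assert(resolve(ColumnRef("col3", base.planId), join) == Some("col3"))
}

One visible trade-off of this design, reflected in the DataFrameSuite change of patch 08, is that df1("col1") == df2("col1") no longer holds before analysis, which is why internal callers and tests switch to colInternal.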