diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 18fd68ec5ef5..de2a9e08e232 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1263,7 +1263,6 @@ def test_access_column(self):
         self.assertTrue(isinstance(df['key'], Column))
         self.assertTrue(isinstance(df[0], Column))
         self.assertRaises(IndexError, lambda: df[2])
-        self.assertRaises(AnalysisException, lambda: df["bad_key"])
         self.assertRaises(TypeError, lambda: df[{}])

     def test_column_name_with_non_ascii(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 73e92066b85c..2ab434ff9b07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -556,31 +556,31 @@ class Analyzer(
       right.collect {
         // Handle base relations that might appear more than once.
         case oldVersion: MultiInstanceRelation
-            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
           val newVersion = oldVersion.newInstance()
           (oldVersion, newVersion)

         case oldVersion: SerializeFromObject
-            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance())))

         // Handle projects that create conflicting aliases.
         case oldVersion @ Project(projectList, _)
-            if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+            if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

         case oldVersion @ Aggregate(_, aggregateExpressions, _)
-            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

         case oldVersion: Generate
-            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
           (oldVersion, oldVersion.copy(generatorOutput = newOutput))

         case oldVersion @ Window(windowExpressions, _, _, child)
-            if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
-              .nonEmpty =>
+            if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
+              .nonEmpty =>
           (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
       }
       // Only handle first case, others will be fixed on the next pass.
@@ -597,11 +597,16 @@ class Analyzer(
           val newRight = right transformUp {
             case r if r == oldRelation => newRelation
           } transformUp {
-            case other => other transformExpressions {
-              case a: Attribute =>
-                attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
-            }
+            case other =>
+              val transformed = other transformExpressions {
+                case a: Attribute =>
+                  attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
+              }
+
+              transformed.setPlanId(other.planId)
+              transformed
           }
+          newRight.setPlanId(right.planId)
           newRight
       }
     }
@@ -664,11 +669,18 @@ class Analyzer(
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
-        q transformExpressionsUp {
-          case u @ UnresolvedAttribute(nameParts) =>
+        q transformExpressionsUp {
+          case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
             // Leave unchanged if resolution fails. Hopefully will be resolved next round.
             val result =
-              withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+              withPosition(u) {
+                targetPlanIdOpt match {
+                  case Some(targetPlanId) =>
+                    resolveExpressionFromSpecificLogicalPlan(nameParts, q, targetPlanId)
+                  case None =>
+                    q.resolveChildren(nameParts, resolver).getOrElse(u)
+                }
+              }
             logDebug(s"Resolving $u to $result")
             result

           case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -746,6 +758,19 @@ class Analyzer(
     exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
   }

+  private[sql] def resolveExpressionFromSpecificLogicalPlan(
+      nameParts: Seq[String],
+      planToSearchFrom: LogicalPlan,
+      targetPlanId: Long): Expression = {
+    lazy val name = UnresolvedAttribute(nameParts).name
+    planToSearchFrom.findByBreadthFirst(_.planId == targetPlanId) match {
+      case Some(foundPlan) =>
+        foundPlan.resolve(nameParts, resolver).get
+      case None =>
+        failAnalysis(s"Could not find $name in any logical plan.")
+    }
+  }
+
   protected[sql] def resolveExpression(
       expr: Expression,
       plan: LogicalPlan,
@@ -757,8 +782,14 @@ class Analyzer(
     try {
       expr transformUp {
         case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
-        case u @ UnresolvedAttribute(nameParts) =>
-          withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+        case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
+          withPosition(u) {
+            targetPlanIdOpt match {
+              case Some(targetPlanId) =>
+                resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
+              case None => plan.resolve(nameParts, resolver).getOrElse(u)
+            }
+          }
         case UnresolvedExtractValue(child, fieldName) if child.resolved =>
           ExtractValue(child, fieldName, resolver)
       }
@@ -986,12 +1017,17 @@ class Analyzer(
     plan transformDown {
       case q: LogicalPlan if q.childrenResolved && !q.resolved =>
         q transformExpressions {
-          case u @ UnresolvedAttribute(nameParts) =>
+          case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
             withPosition(u) {
               try {
-                outer.resolve(nameParts, resolver) match {
-                  case Some(outerAttr) => OuterReference(outerAttr)
-                  case None => u
+                targetPlanIdOpt match {
+                  case Some(targetPlanId) =>
+                    resolveExpressionFromSpecificLogicalPlan(nameParts, outer, targetPlanId)
+                  case None =>
+                    outer.resolve(nameParts, resolver) match {
+                      case Some(outerAttr) => OuterReference(outerAttr)
+                      case None => u
+                    }
                 }
               } catch {
                 case _: AnalysisException => u
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 36ed9ba50372..b5f21493018d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -83,7 +83,9 @@ case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq

 /**
  * Holds the name of an attribute that has yet to be resolved.
  */
-case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
+case class UnresolvedAttribute(
+    nameParts: Seq[String],
+    targetPlanIdOpt: Option[Long] = None) extends Attribute with Unevaluable {

   def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
@@ -162,7 +164,7 @@ object UnresolvedAttribute {
       }
       if (inBacktick) throw e
       nameParts += tmp.mkString
-      nameParts.toSeq
+      nameParts
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 3757eccfa2dd..4925c97faf36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -156,7 +156,7 @@ object ExpressionEncoder {
       } else {
         val input = GetColumnByOrdinal(index, enc.schema)
         val deserialized = enc.deserializer.transformUp {
-          case UnresolvedAttribute(nameParts) =>
+          case UnresolvedAttribute(nameParts, _) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
           case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index c842f85af693..62a99499a346 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -220,7 +220,7 @@ case class AttributeReference(
   extends Attribute with Unevaluable {

   /**
-   * Returns true iff the expression id is the same for both attributes.
+   * Returns true if the expression id is the same for both attributes.
    */
   def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 3969fdb0ffee..9c94f16f5a0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1156,8 +1156,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
     val attr = ctx.fieldName.getText
     expression(ctx.base) match {
-      case UnresolvedAttribute(nameParts) =>
-        UnresolvedAttribute(nameParts :+ attr)
+      case UnresolvedAttribute(nameParts, targetPlanId) =>
+        UnresolvedAttribute(nameParts :+ attr, targetPlanId)
       case e =>
         UnresolvedExtractValue(e, Literal(attr))
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 4f634cb29ddb..843e730c72be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -26,11 +26,20 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.types.StructType

+object LogicalPlan {
+  private val curId = new java.util.concurrent.atomic.AtomicLong()
+  def newPlanId: Long = curId.getAndIncrement()
+}

 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

   private var _analyzed: Boolean = false

+  // A logical plan keeps its planId even when the analyzer
+  // replaces the plan in order to deduplicate expressions
+  // that have the same exprId.
+  private var _planId: Long = LogicalPlan.newPlanId
+
   /**
    * Marks this plan as already analyzed. This should only be called by CheckAnalysis.
    */
@@ -43,6 +52,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    */
   def analyzed: Boolean = _analyzed

+  private[catalyst] def setPlanId(planId: Long): Unit = { _planId = planId }
+
+  def planId: Long = _planId
+
   /** Returns true if this subtree contains any streaming data sources.
    */
   def isStreaming: Boolean = children.exists(_.isStreaming == true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 8cc16d662b60..dcd4fe9292e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.trees
 import java.util.UUID

 import scala.collection.Map
+import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag

 import org.apache.commons.lang3.ClassUtils
@@ -110,6 +111,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
   }

+  def findByBreadthFirst(f: BaseType => Boolean): Option[BaseType] = {
+    val queue = new ArrayBuffer[BaseType]
+    var foundOpt: Option[BaseType] = None
+    queue.append(this)
+
+    // Do a breadth-first search so that the shallowest matching node is found first.
+    while (queue.nonEmpty && foundOpt.isEmpty) {
+      val currentNode = queue.remove(0)
+      f(currentNode) match {
+        case true => foundOpt = Option(currentNode)
+        case false =>
+          val childPlans = currentNode.children.reverse
+          childPlans.foreach(queue.append(_))
+      }
+    }
+    foundOpt
+  }
+
   /**
    * Runs the given function on this node and then recursively on [[children]].
    * @param f the function to be applied to each node in the tree.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 28820681cd3a..c929f198289f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -421,7 +421,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
     val keyExpr = df.col(col.name).expr
-    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+    def buildExpr(v: Any) = Cast(Literal(v), col.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(buildExpr(source), buildExpr(target))
     }.toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fd75d5153802..ed84b2809d3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1067,12 +1067,11 @@ class Dataset[T] private[sql](
    * @group untypedrel
    * @since 2.0.0
    */
-  def col(colName: String): Column = colName match {
-    case "*" =>
-      Column(ResolvedStar(queryExecution.analyzed.output))
-    case _ =>
-      val expr = resolve(colName)
-      Column(expr)
+  def col(colName: String): Column = withStarResolved(colName) {
+    val expr = UnresolvedAttribute(
+      UnresolvedAttribute.parseAttributeName(colName),
+      Some(queryExecution.analyzed.planId))
+    Column(expr)
   }

   /**
@@ -1949,9 +1948,17 @@ class Dataset[T] private[sql](
    */
   def drop(col: Column): DataFrame = {
     val expression = col match {
-      case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(
-          u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
+      case Column(u @ UnresolvedAttribute(nameParts, targetPlanIdOpt)) =>
+        val plan = queryExecution.analyzed
+        val analyzer = sparkSession.sessionState.analyzer
+        val resolver = analyzer.resolver
+
+        targetPlanIdOpt match {
+          case Some(targetPlanId) =>
+            analyzer.resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
+          case None =>
+            plan.resolveQuoted(u.name, resolver).getOrElse(u)
+        }
       case Column(expr: Expression) => expr
     }
     val attrs = this.logicalPlan.output
@@ -2786,6 +2793,19 @@ class Dataset[T] private[sql](
     }
   }

+  /** Another version of `col` which resolves an expression immediately.
+    * Mainly intended for tests, for example when passing columns to a SparkPlan.
+    */
+  private[sql] def colInternal(colName: String): Column = withStarResolved(colName) {
+    val expr = resolve(colName)
+    Column(expr)
+  }
+
+  private def withStarResolved(colName: String)(f: => Column): Column = colName match {
+    case "*" => Column(ResolvedStar(queryExecution.analyzed.output))
+    case _ => f
+  }
+
   /** A convenient function to wrap a logical plan and produce a DataFrame. */
   @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
     Dataset.ofRows(sparkSession, logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 0fe8d87ebd6b..b149f7c4664c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -411,7 +411,9 @@ class RelationalGroupedDataset protected[sql](
       packageNames: Array[Byte],
       broadcastVars: Array[Broadcast[Object]],
       outputSchema: StructType): DataFrame = {
-    val groupingNamedExpressions = groupingExprs.map(alias)
+    val groupingNamedExpressions = groupingExprs
+      .map(df.sparkSession.sessionState.analyzer.resolveExpression(_, df.logicalPlan))
+      .map(alias)
     val groupingCols = groupingNamedExpressions.map(Column(_))
     val groupingDataFrame = df.select(groupingCols : _*)
     val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f4df80fd9c93..7a955c36881d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -627,7 +627,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
         Row(id, name, age, salary)
       }.toSeq)
     assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
-    assert(df("id") == person("id"))
+    val dfAnalyzer = df.sparkSession.sessionState.analyzer
+    val personAnalyzer = person.sparkSession.sessionState.analyzer
+    assert(dfAnalyzer.resolveExpression(df("id").expr, df.queryExecution.analyzed) ==
+      personAnalyzer.resolveExpression(person("id").expr, person.queryExecution.analyzed))
   }

   test("drop top level columns that contains dot") {
@@ -1601,6 +1604,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }

+  test("""SPARK-17154: df("column_name") should return correct result when we do self-join""") {
+    val df = Seq(
+      (1, "a", "A"),
+      (2, "b", "B"),
+      (3, "c", "C"),
+      (4, "d", "D"),
+      (5, "e", "E")).toDF("col1", "col2", "col3")
+    val filtered = df.filter("col1 != 3").select("col1", "col2")
+    val joined = filtered.join(df, filtered("col1") === df("col1"), "inner")
+    val selected1 = joined.select(df("col3"))
+
+    checkAnswer(selected1, Row("A") :: Row("B") :: Row("D") :: Row("E") :: Nil)
+
+    val rightOuterJoined = filtered.join(df, filtered("col1") === df("col1"), "right")
+    val selected2 = rightOuterJoined.select(df("col1"))
+
+    checkAnswer(selected2, Row(1) :: Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
+
+    val selected3 = rightOuterJoined.select(filtered("col1"))
+    checkAnswer(selected3, Row(1) :: Row(2) :: Row(null) :: Row(4) :: Row(5) :: Nil)
+  }
+
   test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") {
     withTable("bar") {
       withTempView("foo") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index b29e822add8b..d3f7216875cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -242,7 +242,7 @@ object SparkPlanTest {
       case plan: SparkPlan =>
         val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
         plan transformExpressions {
-          case UnresolvedAttribute(Seq(u)) =>
+          case UnresolvedAttribute(Seq(u), _) =>
             inputMap.getOrElse(u,
               sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 38377164c10e..3574db4bce82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -62,16 +62,16 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       Row(6, null)
     )), new StructType().add("c", IntegerType).add("d", DoubleType))

-  private lazy val singleConditionEQ = (left.col("a") === right.col("c")).expr
+  private lazy val singleConditionEQ = (left.colInternal("a") === right.colInternal("c")).expr

   private lazy val composedConditionEQ = {
-    And((left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And((left.colInternal("a") === right.colInternal("c")).expr,
+      LessThan(left.colInternal("b").expr, right.colInternal("d").expr))
   }

   private lazy val composedConditionNEQ = {
-    And((left.col("a") < right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And((left.colInternal("a") < right.colInternal("c")).expr,
+      LessThan(left.colInternal("b").expr, right.colInternal("d").expr))
   }

   // Note: the input dataframes and expression must be evaluated lazily because
@@ -252,6 +252,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       LeftAnti,
       left,
       rightUniqueKey,
-      (left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")).expr,
+      (left.colInternal("a") === rightUniqueKey.colInternal("c") &&
+        left.colInternal("b") < rightUniqueKey.colInternal("d")).expr,
       Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null)))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 4408ece11225..e662d18686f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -219,7 +219,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, one match per row",
       myUpperCaseData,
       myLowerCaseData,
-      () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
+      () => (myUpperCaseData.colInternal("N") === myLowerCaseData.colInternal("n")).expr,
       Seq(
         (1, "A", 1, "a"),
         (2, "B", 2, "b"),
@@ -235,7 +235,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, multiple matches",
       left,
       right,
-      () => (left.col("a") === right.col("a")).expr,
+      () => (left.colInternal("a") === right.colInternal("a")).expr,
       Seq(
         (1, 1, 1, 1),
         (1, 1, 1, 2),
@@ -252,7 +252,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, no matches",
       left,
       right,
-      () => (left.col("a") === right.col("a")).expr,
+      () => (left.colInternal("a") === right.colInternal("a")).expr,
       Seq.empty
     )
   }
@@ -264,7 +264,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "inner join, null safe",
       left,
       right,
-      () => (left.col("b") <=> right.col("b")).expr,
+      () => (left.colInternal("b") <=> right.colInternal("b")).expr,
       Seq(
         (1, 0, 1, 0),
         (2, null, 2, null)
@@ -280,7 +280,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       "SPARK-15822 - test structs as keys",
       left,
       right,
-      () => (left.col("key") === right.col("key")).expr,
+      () => (left.colInternal("key") === right.colInternal("key")).expr,
       Seq(
         (Row(0, 0), "L0", Row(0, 0), "R0"),
         (Row(1, 1), "L1", Row(1, 1), "R1"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 001feb0f2b39..297621663e1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -57,8 +57,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     )), new StructType().add("c", IntegerType).add("d", DoubleType))

   private lazy val condition = {
-    And((left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
+    And((left.colInternal("a") === right.colInternal("c")).expr,
+      LessThan(left.colInternal("b").expr, right.colInternal("d").expr))
   }

   // Note: the input dataframes and expression must be evaluated lazily because
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 0e837766e2ea..24eed27fda74 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -54,7 +54,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -68,7 +68,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -83,7 +83,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -100,7 +100,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -117,7 +117,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {

     val e = intercept[SparkException] {
       val plan = new ScriptTransformation(
-        input = Seq(rowsDf.col("a").expr),
+        input = Seq(rowsDf.colInternal("a").expr),
         script = "some_non_existent_command",
         output = Seq(AttributeReference("a", StringType)()),
         child = rowsDf.queryExecution.sparkPlan,
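
For reference, a minimal usage sketch (not part of the patch) of the behaviour the new DataFrameSuite test exercises; it assumes a SparkSession named `spark` with its implicits imported:

    import spark.implicits._

    val df = Seq((1, "a", "A"), (2, "b", "B"), (3, "c", "C")).toDF("col1", "col2", "col3")
    val filtered = df.filter("col1 != 3").select("col1", "col2")

    // Both sides of this self-join expose a `col1` attribute. Because df("col3") now
    // carries the planId of the plan that `df` was created from, the analyzer can
    // resolve it against that specific plan even after the join duplicates the
    // underlying attributes.
    val joined = filtered.join(df, filtered("col1") === df("col1"), "inner")
    joined.select(df("col3")).show()  // "A" and "B"; the row with col1 == 3 was filtered out

The lookup relies on TreeNode.findByBreadthFirst to locate the node with the matching planId, and on dedupRight propagating the original planId via setPlanId when the analyzer rewrites the right side of a join.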