From dd59446ea77a9ff3b5a408b617e3d4268da7b2b9 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 29 Jan 2020 18:56:37 -0800 Subject: [PATCH 01/11] initial commit --- .../sql/catalyst/analysis/Analyzer.scala | 16 +-- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../expressions/namedExpressions.scala | 2 - .../sql/catalyst/expressions/package.scala | 117 +++++++----------- .../spark/sql/catalyst/identifiers.scala | 16 +-- .../plans/logical/basicLogicalOperators.scala | 17 ++- .../catalyst/plans/logical/v2Commands.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 3 +- .../sql/connector/catalog/CatalogV2Util.scala | 15 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 10 +- .../catalog/CatalogV2UtilSuite.scala | 11 +- .../sql-tests/results/group-by-filter.sql.out | 32 ++--- .../invalid-correlation.sql.out | 4 +- .../benchmark/TPCDSQueryBenchmark.scala | 2 +- .../command/PlanResolutionSuite.scala | 89 +++++++------ 15 files changed, 175 insertions(+), 165 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 15ebf6971d9d..4719dbccbc78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -812,7 +812,11 @@ class Analyzer( case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) if i.query.resolved => lookupV2Relation(u.multipartIdentifier) - .map(v2Relation => i.copy(table = v2Relation)) + .map { + EliminateSubqueryAliases(_) match { + case r: DataSourceV2Relation => i.copy(table = r) + } + } .getOrElse(i) case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) => @@ -827,14 +831,10 @@ class Analyzer( /** * Performs the lookup of DataSourceV2 Tables from v2 catalog. 
*/ - private def lookupV2Relation(identifier: Seq[String]): Option[DataSourceV2Relation] = + private def lookupV2Relation(identifier: Seq[String]): Option[LogicalPlan] = expandRelationName(identifier) match { case NonSessionCatalogAndIdentifier(catalog, ident) => - CatalogV2Util.loadTable(catalog, ident) match { - case Some(table) => - Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident))) - case None => None - } + CatalogV2Util.loadRelation(catalog, ident) case _ => None } } @@ -922,7 +922,7 @@ class Analyzer( case v1Table: V1Table => v1SessionCatalog.getRelation(v1Table.v1Table) case table => - DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + CatalogV2Util.getRelation(catalog, ident, table) } val key = catalog.name +: ident.namespace :+ ident.name Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d6fc1dc6ddc3..de510589781c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -425,8 +425,8 @@ trait CheckAnalysis extends PredicateHelper { case _ => } - case alter: AlterTable if alter.childrenResolved => - val table = alter.table + case alter @ AlterTable(_, _, SubqueryAlias(_, table: NamedRelation), _) + if alter.childrenResolved => def findField(operation: String, fieldName: Array[String]): StructField = { // include collections because structs nested in maps and arrays may be altered val field = table.schema.findNestedField(fieldName, includeCollections = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3362353e2662..02e90f8458c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -236,8 +236,6 @@ case class AttributeReference( val qualifier: Seq[String] = Seq.empty[String]) extends Attribute with Unevaluable { - // currently can only handle qualifier of length 2 - require(qualifier.length <= 2) /** * Returns true iff the expression id is the same for both attributes. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 7164b6b82adb..e208a37fb7c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -23,7 +23,6 @@ import com.google.common.collect.Maps import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} /** @@ -128,90 +127,66 @@ package object expressions { m.mapValues(_.distinct).map(identity) } - /** Map to use for direct case insensitive attribute lookups. 
*/ - @transient private lazy val direct: Map[String, Seq[Attribute]] = { + /** Attribute name to attributes */ + @transient private val attrsMap: Map[String, Seq[Attribute]] = { unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) } - /** Map to use for qualified case insensitive attribute lookups with 2 part key */ - @transient private lazy val qualified: Map[(String, String), Seq[Attribute]] = { - // key is 2 part: table/alias and name - val grouped = attrs.filter(_.qualifier.nonEmpty).groupBy { - a => (a.qualifier.last.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) - } - unique(grouped) - } - - /** Map to use for qualified case insensitive attribute lookups with 3 part key */ - @transient private val qualified3Part: Map[(String, String, String), Seq[Attribute]] = { - // key is 3 part: database name, table name and name - val grouped = attrs.filter(_.qualifier.length == 2).groupBy { a => - (a.qualifier.head.toLowerCase(Locale.ROOT), - a.qualifier.last.toLowerCase(Locale.ROOT), - a.name.toLowerCase(Locale.ROOT)) - } - unique(grouped) - } - /** Perform attribute resolution given a name and a resolver. */ def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { - // Collect matching attributes given a name and a lookup. - def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { - candidates.toSeq.flatMap(_.collect { - case a if resolver(a.name, name) => a.withName(name) - }) + // Returns true if the `short` qualifier is a subset of the last elements of + // `long` qualifier. For example, Seq("a", "b") is a subset of Seq("a", "a", "b"), + // but not a subset of Seq("a", "b", "b"). + def matchQualifier(short: Seq[String], long: Seq[String]): Boolean = { + (long.length >= short.length) && + long.takeRight(short.length) + .zip(short) + .filterNot(x => resolver(x._1, x._2)) + .isEmpty } - // Find matches for the given name assuming that the 1st two parts are qualifier - // (i.e. database name and table name) and the 3rd part is the actual column name. - // - // For example, consider an example where "db1" is the database name, "a" is the table name - // and "b" is the column name and "c" is the struct field name. - // If the name parts is db1.a.b.c, then Attribute will match - // Attribute(b, qualifier("db1,"a")) and List("c") will be the second element - var matches: (Seq[Attribute], Seq[String]) = nameParts match { - case dbPart +: tblPart +: name +: nestedFields => - val key = (dbPart.toLowerCase(Locale.ROOT), - tblPart.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) - val attributes = collectMatches(name, qualified3Part.get(key)).filter { - a => (resolver(dbPart, a.qualifier.head) && resolver(tblPart, a.qualifier.last)) - } - (attributes, nestedFields) - case _ => - (Seq.empty, Seq.empty) + // Collect attributes that match the given name and qualifier. + // A match occurs if + // 1) the given name matches the attribute's name according to the resolver. + // 2) the given qualifier is a subset of the attribute's qualifier. + def collectMatches( + name: String, + qualifier: Seq[String], + candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(name, a.name) && matchQualifier(qualifier, a.qualifier) => + a.withName(name) + }) } - // If there are no matches, then find matches for the given name assuming that - // the 1st part is a qualifier (i.e. table name, alias, or subquery alias) and the - // 2nd part is the actual name. 
This returns a tuple of - // matched attributes and a list of parts that are to be resolved. - // - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - if (matches._1.isEmpty) { - matches = nameParts match { - case qualifier +: name +: nestedFields => - val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) - val attributes = collectMatches(name, qualified.get(key)).filter { a => - resolver(qualifier, a.qualifier.last) - } - (attributes, nestedFields) - case _ => - (Seq.empty[Attribute], Seq.empty[String]) + // Iterate each string in `nameParts` in a reverse order and try to match the attributes + // considering the current string as the attribute name. For example, if `nameParts` is + // Seq("a", "b", "c"), the match will be performed in the following order: + // 1) name = "c", qualifier = Seq("a", "b") + // 2) name = "b", qualifier = Seq("a") + // 3) name = "a", qualifier = Seq() + // Note that the match is performed in the reverse order in order to match the longest + // qualifier as possible. If a match is found, the remaining portion of `nameParts` + // is also returned as nested fields. + val matches = nameParts.zipWithIndex.reverseIterator.flatMap { case (name, index) => + val matched = collectMatches( + name, + nameParts.take(index), + attrsMap.get(name.toLowerCase(Locale.ROOT))) + if (matched.nonEmpty) { + (matched, nameParts.takeRight(nameParts.length - index - 1)) :: Nil + } else { + Nil } } - // If none of attributes match database.table.column pattern or - // `table.column` pattern, we try to resolve it as a column. - val (candidates, nestedFields) = matches match { - case (Seq(), _) => - val name = nameParts.head - val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) - (attributes, nameParts.tail) - case _ => matches + if (matches.isEmpty) { + return None } + // Note that `matches` is an iterator, and only the first match will be used. + val (candidates, nestedFields) = matches.next + def name = UnresolvedAttribute(nameParts).name candidates match { case Seq(a) if nestedFields.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index deceec73dda3..6861ffe07fce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -49,19 +49,21 @@ sealed trait IdentifierWithDatabase { /** * Encapsulates an identifier that is either a alias name or an identifier that has table - * name and optionally a database name. + * name and a namespace. 
* The SubqueryAlias node keeps track of the qualifier using the information in this structure - * @param identifier - Is an alias name or a table name - * @param database - Is a database name and is optional + * @param name - Is an alias name or a table name + * @param namespace - Is a namespace */ -case class AliasIdentifier(identifier: String, database: Option[String]) - extends IdentifierWithDatabase { +case class AliasIdentifier(name: String, namespace: Seq[String]) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def this(identifier: String) = this(identifier, Seq()) - def this(identifier: String) = this(identifier, None) + override def toString: String = (namespace :+ name).quoted } object AliasIdentifier { - def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier) + def apply(name: String): AliasIdentifier = new AliasIdentifier(name) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 40db8b6f49dc..222dd07797b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.types._ import org.apache.spark.util.random.RandomSampler @@ -849,18 +850,18 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi /** * Aliased subquery. * - * @param name the alias identifier for this subquery. + * @param identifier the alias identifier for this subquery. * @param child the logical plan of this subquery. 
*/ case class SubqueryAlias( - name: AliasIdentifier, + identifier: AliasIdentifier, child: LogicalPlan) extends OrderPreservingUnaryNode { - def alias: String = name.identifier + def alias: String = identifier.name override def output: Seq[Attribute] = { - val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) + val qualifierList = identifier.namespace :+ alias child.output.map(_.withQualifier(qualifierList)) } override def doCanonicalize(): LogicalPlan = child.canonicalized @@ -877,7 +878,13 @@ object SubqueryAlias { identifier: String, database: String, child: LogicalPlan): SubqueryAlias = { - SubqueryAlias(AliasIdentifier(identifier, Some(database)), child) + SubqueryAlias(AliasIdentifier(identifier, Seq(database)), child) + } + + def apply( + identifier: Identifier, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier.name, identifier.namespace), child) } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index c04e56355a68..289db7dc07df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -389,7 +389,7 @@ case class DropTable( case class AlterTable( catalog: TableCatalog, ident: Identifier, - table: NamedRelation, + table: LogicalPlan, changes: Seq[TableChange]) extends Command { override lazy val resolved: Boolean = table.resolved && { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ba1eeb38e247..56a198763b4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -27,7 +27,7 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.sql.catalyst.IdentifierWithDatabase +import org.apache.spark.sql.catalyst.{AliasIdentifier, IdentifierWithDatabase} import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} import org.apache.spark.sql.catalyst.errors._ @@ -780,6 +780,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case exprId: ExprId => true case field: StructField => true case id: IdentifierWithDatabase => true + case alias: AliasIdentifier => true case join: JoinType => true case spec: BucketSpec => true case catalog: CatalogTable => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 0fabe4df6c9a..cf1eedfed932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -22,8 +22,9 @@ import java.util.Collections import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} -import org.apache.spark.sql.catalyst.plans.logical.AlterTable +import 
org.apache.spark.sql.catalyst.AliasIdentifier +import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, LogicalPlan, SubqueryAlias} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType} @@ -285,8 +286,14 @@ private[sql] object CatalogV2Util { case _: NoSuchNamespaceException => None } - def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[NamedRelation] = { - loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident))) + def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[LogicalPlan] = { + loadTable(catalog, ident).map(getRelation(catalog, ident, _)) + } + + def getRelation(catalog: CatalogPlugin, ident: Identifier, table: Table): LogicalPlan = { + SubqueryAlias( + ident, + DataSourceV2Relation.create(table, Some(catalog), Some(ident))) } def isSessionCatalog(catalog: CatalogPlugin): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 0e094bc06b05..0795f378eaf9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -433,10 +433,11 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { // Converts AliasIdentifier to JSON assertJSON( - AliasIdentifier("alias"), + AliasIdentifier("alias", Seq("ns1", "ns2")), JObject( "product-class" -> JString(classOf[AliasIdentifier].getName), - "identifier" -> "alias")) + "name" -> "alias", + "namespace" -> "[ns1, ns2]")) // Converts SubqueryAlias to JSON assertJSON( @@ -445,8 +446,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { JObject( "class" -> classOf[SubqueryAlias].getName, "num-children" -> 1, - "name" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName), - "identifier" -> "t1"), + "identifier" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName), + "name" -> "t1", + "namespace" -> JArray(Nil)), "child" -> 0), JObject( "class" -> classOf[JsonTestTreeNode].getName, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala index 7a9a7f52ff8f..f545f2ac112c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.catalog import org.mockito.Mockito.{mock, when} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.StructType @@ -32,9 +33,11 @@ class CatalogV2UtilSuite extends SparkFunSuite { when(testCatalog.loadTable(ident)).thenReturn(table) val r = CatalogV2Util.loadRelation(testCatalog, ident) assert(r.isDefined) - assert(r.get.isInstanceOf[DataSourceV2Relation]) - val v2Relation = r.get.asInstanceOf[DataSourceV2Relation] - 
assert(v2Relation.catalog.exists(_ == testCatalog)) - assert(v2Relation.identifier.exists(_ == ident)) + r.get match { + case SubqueryAlias(_, v2Relation: DataSourceV2Relation) => + assert(v2Relation.catalog.exists(_ == testCatalog)) + assert(v2Relation.identifier.exists(_ == ident)) + case _ => fail() + } } } diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index a032678e90fe..a4c7c2cf90cd 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -369,13 +369,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE exists#x [dept_id#x]) AS avg(salary) FILTER (WHERE exists(dept_id))#x] : +- Project [state#x] : +- Filter (dept_id#x = outer(dept_id#x)) -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; @@ -395,13 +395,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT exists#x [dept_id#x]) AS sum(salary) FILTER (WHERE (NOT exists(dept_id)))#x] : +- Project [state#x] : +- Filter (dept_id#x = outer(dept_id#x)) -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; @@ -420,13 +420,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, avg(salary#x) AS avg(salary)#x, avg(salary#x) FILTER (WHERE dept_id#x IN (list#x [])) AS avg(salary) FILTER (WHERE (dept_id IN (listquery())))#x] : +- Distinct : +- Project [dept_id#x] -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; @@ -445,13 +445,13 @@ org.apache.spark.sql.AnalysisException IN/EXISTS predicate sub-queries can only be used in Filter/Join and a few commands: Aggregate [dept_id#x], [dept_id#x, sum(salary#x) AS sum(salary)#x, sum(salary#x) FILTER (WHERE NOT dept_id#x IN (list#x [])) AS sum(salary) FILTER (WHERE (NOT (dept_id IN (listquery()))))#x] : +- Distinct : +- Project [dept_id#x] -: +- SubqueryAlias `dept` +: +- SubqueryAlias dept : +- Project [dept_id#x, dept_name#x, 
state#x] -: +- SubqueryAlias `DEPT` +: +- SubqueryAlias DEPT : +- LocalRelation [dept_id#x, dept_name#x, state#x] -+- SubqueryAlias `emp` ++- SubqueryAlias emp +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] - +- SubqueryAlias `EMP` + +- SubqueryAlias EMP +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x] ; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 1599634ff9ef..ec7ecf28754e 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -110,8 +110,8 @@ struct<> org.apache.spark.sql.AnalysisException Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: Aggregate [min(outer(t2a#x)) AS min(outer())#x] -+- SubqueryAlias `t3` ++- SubqueryAlias t3 +- Project [t3a#x, t3b#x, t3c#x] - +- SubqueryAlias `t3` + +- SubqueryAlias t3 +- LocalRelation [t3a#x, t3b#x, t3c#x] ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index c93d27f02c68..ad3d79760adf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -81,7 +81,7 @@ object TPCDSQueryBenchmark extends SqlBasedBenchmark { val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.analyzed.foreach { case SubqueryAlias(alias, _: LogicalRelation) => - queryRelations.add(alias.identifier) + queryRelations.add(alias.name) case LogicalRelation(_, _, Some(catalogTable), _) => queryRelations.add(catalogTable.identifier.table) case HiveTableRelation(tableMeta, _, _, _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 70b9b7ec12ea..0aff678e7944 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -726,7 +726,7 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed3, expected3) } else { parsed1 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.setProperty("test", "test"), TableChange.setProperty("comment", "new_comment"))) @@ -734,7 +734,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed2 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.removeProperty("comment"), TableChange.removeProperty("test"))) @@ -742,7 +742,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed3 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.removeProperty("comment"), TableChange.removeProperty("test"))) @@ -785,7 +785,7 @@ class PlanResolutionSuite extends AnalysisTest { 
comparePlans(parsed, expected) } else { parsed match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.setProperty("a", "1"), TableChange.setProperty("b", "0.1"), @@ -809,7 +809,7 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed, expected) } else { parsed match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq(TableChange.setProperty("location", "new location"))) case _ => fail("Expect AlterTable, but got:\n" + parsed.treeString) } @@ -883,33 +883,34 @@ class PlanResolutionSuite extends AnalysisTest { val parsed4 = parseAndResolve(sql4) parsed1 match { - case DeleteFromTable(_: DataSourceV2Relation, None) => - case _ => fail("Expect DeleteFromTable, bug got:\n" + parsed1.treeString) + case DeleteFromTable(AsDataSourceV2Relation(_), None) => + case _ => fail("Expect DeleteFromTable, but got:\n" + parsed1.treeString) } parsed2 match { case DeleteFromTable( - _: DataSourceV2Relation, + AsDataSourceV2Relation(_), Some(EqualTo(name: UnresolvedAttribute, StringLiteral("Robert")))) => assert(name.name == "name") - case _ => fail("Expect DeleteFromTable, bug got:\n" + parsed2.treeString) + case _ => fail("Expect DeleteFromTable, but got:\n" + parsed2.treeString) } parsed3 match { case DeleteFromTable( - SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Some(EqualTo(name: UnresolvedAttribute, StringLiteral("Robert")))) => assert(name.name == "t.name") - case _ => fail("Expect DeleteFromTable, bug got:\n" + parsed3.treeString) + case _ => fail("Expect DeleteFromTable, but got:\n" + parsed3.treeString) } parsed4 match { - case DeleteFromTable(SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + case DeleteFromTable( + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Some(InSubquery(values, query))) => assert(values.size == 1 && values.head.isInstanceOf[UnresolvedAttribute]) assert(values.head.asInstanceOf[UnresolvedAttribute].name == "t.name") query match { - case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", None), + case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()), UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))), _, _, _) => assert(projects.size == 1 && projects.head.name == "s.name") @@ -942,7 +943,7 @@ class PlanResolutionSuite extends AnalysisTest { parsed1 match { case UpdateTable( - _: DataSourceV2Relation, + AsDataSourceV2Relation(_), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), None) => @@ -954,7 +955,9 @@ class PlanResolutionSuite extends AnalysisTest { parsed2 match { case UpdateTable( - SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + SubqueryAlias( + AliasIdentifier("t", Seq()), + AsDataSourceV2Relation(_)), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), None) => @@ -966,7 +969,9 @@ class PlanResolutionSuite extends AnalysisTest { parsed3 match { case UpdateTable( - SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + SubqueryAlias( + AliasIdentifier("t", Seq()), + AsDataSourceV2Relation(_)), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), 
Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), Some(EqualTo(p: UnresolvedAttribute, IntegerLiteral(1)))) => @@ -978,14 +983,16 @@ class PlanResolutionSuite extends AnalysisTest { } parsed4 match { - case UpdateTable(SubqueryAlias(AliasIdentifier("t", None), _: DataSourceV2Relation), + case UpdateTable( + SubqueryAlias(AliasIdentifier("t", Seq()), + AsDataSourceV2Relation(_)), Seq(Assignment(key: UnresolvedAttribute, IntegerLiteral(32))), Some(InSubquery(values, query))) => assert(key.name == "t.age") assert(values.size == 1 && values.head.isInstanceOf[UnresolvedAttribute]) assert(values.head.asInstanceOf[UnresolvedAttribute].name == "t.name") query match { - case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", None), + case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()), UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))), _, _, _) => assert(projects.size == 1 && projects.head.name == "s.name") @@ -1051,14 +1058,14 @@ class PlanResolutionSuite extends AnalysisTest { val parsed3 = parseAndResolve(sql3) parsed1 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.updateColumnType(Array("i"), LongType))) case _ => fail("expect AlterTable") } parsed2 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.updateColumnType(Array("i"), LongType), TableChange.updateColumnComment(Array("i"), "new comment"))) @@ -1066,7 +1073,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed3 match { - case AlterTable(_, _, _: DataSourceV2Relation, changes) => + case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => assert(changes == Seq( TableChange.updateColumnComment(Array("i"), "new comment"))) case _ => fail("expect AlterTable") @@ -1102,10 +1109,10 @@ class PlanResolutionSuite extends AnalysisTest { val catlogIdent = if (isSessionCatlog) v2SessionCatalog else testCat val tableIdent = if (isSessionCatlog) "v2Table" else "tab" parsed match { - case AlterTable(_, _, r: DataSourceV2Relation, _) => + case AlterTable(_, _, AsDataSourceV2Relation(r), _) => assert(r.catalog.exists(_ == catlogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) - case Project(_, r: DataSourceV2Relation) => + case Project(_, AsDataSourceV2Relation(r)) => assert(r.catalog.exists(_ == catlogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) case InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) => @@ -1182,8 +1189,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql1) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), @@ -1208,8 +1215,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql2) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: 
DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, @@ -1234,8 +1241,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql3) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), mergeCondition, Seq(DeleteAction(None), UpdateAction(None, updateAssigns)), Seq(InsertAction(None, insertAssigns))) => @@ -1258,8 +1265,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql4) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: Project), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), source: Project), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), @@ -1287,8 +1294,8 @@ class PlanResolutionSuite extends AnalysisTest { """.stripMargin parseAndResolve(sql5) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), target: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), source: Project), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), source: Project), mergeCondition, Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), UpdateAction(Some(EqualTo(ul: AttributeReference, StringLiteral("update"))), @@ -1322,8 +1329,8 @@ class PlanResolutionSuite extends AnalysisTest { parseAndResolve(sql1) match { case MergeIntoTable( - target: DataSourceV2Relation, - source: DataSourceV2Relation, + AsDataSourceV2Relation(target), + AsDataSourceV2Relation(source), _, Seq(DeleteAction(None), UpdateAction(None, updateAssigns)), Seq(InsertAction( @@ -1429,8 +1436,8 @@ class PlanResolutionSuite extends AnalysisTest { parseAndResolve(sql) match { case MergeIntoTable( - SubqueryAlias(AliasIdentifier("target", None), _: DataSourceV2Relation), - SubqueryAlias(AliasIdentifier("source", None), _: DataSourceV2Relation), + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(_)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(_)), EqualTo(l: UnresolvedAttribute, r: UnresolvedAttribute), Seq( DeleteAction(Some(EqualTo(dl: UnresolvedAttribute, StringLiteral("delete")))), @@ -1457,3 +1464,11 @@ class PlanResolutionSuite extends AnalysisTest { } // TODO: add tests for more commands. 
} + +object AsDataSourceV2Relation { + def unapply(plan: LogicalPlan): Option[DataSourceV2Relation] = plan match { + case SubqueryAlias(_, r: DataSourceV2Relation) => Some(r) + case _ => None + } +} + From 713c0fb7e2f476c1b0763c03c8d381d5a3deaf3f Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 29 Jan 2020 19:31:02 -0800 Subject: [PATCH 02/11] fix compilation error in hive test --- .../org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 20bafd832d0d..b8ef44b096ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias(AliasIdentifier("vw1", Some("default")), _) => x + case x @ SubqueryAlias(AliasIdentifier("vw1", Seq("default")), _) => x } assert(aliases.size == 1) } From e744f81e8f46c25ebfd22f077f79384a5a5f2c22 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 29 Jan 2020 20:23:50 -0800 Subject: [PATCH 03/11] catalog name support + add tests --- .../sql/connector/catalog/CatalogV2Util.scala | 2 +- .../sql/connector/DataSourceV2SQLSuite.scala | 38 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index cf1eedfed932..fee5bb5624df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -292,7 +292,7 @@ private[sql] object CatalogV2Util { def getRelation(catalog: CatalogPlugin, ident: Identifier, table: Table): LogicalPlan = { SubqueryAlias( - ident, + Identifier.of(catalog.name +: ident.namespace, ident.name), DataSourceV2Relation.create(table, Some(catalog), Some(ident))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 04e5a8dfd78b..20581bb65638 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -678,6 +678,44 @@ class DataSourceV2SQLSuite } } + test("qualified column names for v2 tables") { + val t = "testcat.ns1.ns2.tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id bigint, point struct) USING foo") + sql(s"INSERT INTO $t VALUES (1, (10, 20))") + + checkAnswer( + sql(s"SELECT testcat.ns1.ns2.tbl.id, testcat.ns1.ns2.tbl.point.x FROM $t"), + Row(1, 10)) + checkAnswer(sql(s"SELECT ns1.ns2.tbl.id, ns1.ns2.tbl.point.x FROM $t"), Row(1, 10)) + checkAnswer(sql(s"SELECT ns2.tbl.id, ns2.tbl.point.x FROM $t"), Row(1, 10)) + checkAnswer(sql(s"SELECT tbl.id, tbl.point.x FROM $t"), Row(1, 10)) + + val ex = intercept[AnalysisException] { + sql(s"SELECT ns1.ns2.ns3.tbl.id from $t") + } + assert(ex.getMessage.contains("cannot resolve '`ns1.ns2.ns3.tbl.id`")) + } + } + 
+ test("qualified column names for v1 tables") { + // unset this config to use the default v2 session catalog. + spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + + withTable("t") { + sql("CREATE TABLE t USING json AS SELECT 1 AS i") + checkAnswer(sql("select default.t.i from spark_catalog.t"), Row(1)) + checkAnswer(sql("select t.i from spark_catalog.default.t"), Row(1)) + checkAnswer(sql("select default.t.i from spark_catalog.default.t"), Row(1)) + + // catalog name cannot be used for v1 tables. + val ex = intercept[AnalysisException] { + sql(s"select spark_catalog.default.t.i from spark_catalog.default.t") + } + assert(ex.getMessage.contains("cannot resolve '`spark_catalog.default.t.i`")) + } + } + test("InsertInto: append - across catalog") { val t1 = "testcat.ns1.ns2.tbl" val t2 = "testcat2.db.tbl" From 631304aa9e95eb53d39d657b49be86a1146780f6 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 29 Jan 2020 21:03:04 -0800 Subject: [PATCH 04/11] test fix --- .../apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala index f545f2ac112c..fbfcc567cd2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType class CatalogV2UtilSuite extends SparkFunSuite { test("Load relation should encode the identifiers for V2Relations") { val testCatalog = mock(classOf[TableCatalog]) - val ident = mock(classOf[Identifier]) + val ident = Identifier.of(Array("ns1", "ns2"), "tbl") val table = mock(classOf[Table]) when(table.schema()).thenReturn(mock(classOf[StructType])) when(testCatalog.loadTable(ident)).thenReturn(table) From c414d7b1eaa3c2129ff78bbedf11525221801d85 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 30 Jan 2020 11:43:49 -0800 Subject: [PATCH 05/11] Added unit tests --- .../sql/catalyst/expressions/package.scala | 3 +- .../AttributeResolutionSuite.scala | 105 ++++++++++++++++++ 2 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index e208a37fb7c1..10c2aa77f937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -141,8 +141,7 @@ package object expressions { (long.length >= short.length) && long.takeRight(short.length) .zip(short) - .filterNot(x => resolver(x._1, x._2)) - .isEmpty + .forall(x => resolver(x._1, x._2)) } // Collect attributes that match the given name and qualifier. 
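Before the new test suite, it may help to see the matching strategy from package.scala in isolation. The following is a minimal, self-contained Scala sketch of that strategy, not the Spark code itself: `Attr` and `resolveParts` are illustrative stand-ins for `AttributeReference` and `resolve`, and the ambiguity and nested-field handling of the real implementation is omitted.

object QualifierMatchSketch {
  // Stand-in for AttributeReference: an attribute name plus its qualifier parts.
  case class Attr(name: String, qualifier: Seq[String])

  // Case-insensitive resolver, mirroring caseInsensitiveResolution.
  private val resolver: (String, String) => Boolean = _ equalsIgnoreCase _

  // `short` must match the tail of `long`: Seq("a", "b") matches Seq("a", "a", "b")
  // but not Seq("a", "b", "b").
  private def matchQualifier(short: Seq[String], long: Seq[String]): Boolean =
    long.length >= short.length &&
      long.takeRight(short.length).zip(short).forall { case (l, s) => resolver(l, s) }

  // Walk nameParts from the right: first try the last part as the attribute name
  // with everything before it as the qualifier, then the second-to-last, and so on.
  // The first hit wins, so the longest possible qualifier is matched; whatever
  // follows the matched name is returned as nested fields.
  def resolveParts(nameParts: Seq[String], attrs: Seq[Attr]): Option[(Attr, Seq[String])] = {
    val matches = nameParts.zipWithIndex.reverseIterator.flatMap { case (name, index) =>
      attrs.collect {
        case a if resolver(name, a.name) && matchQualifier(nameParts.take(index), a.qualifier) =>
          (a, nameParts.drop(index + 1))
      }
    }
    if (matches.hasNext) Some(matches.next()) else None
  }

  def main(args: Array[String]): Unit = {
    val attrs = Seq(Attr("a", Seq("ns1", "ns2", "t")))
    // Resolves with name "a", qualifier Seq("ns2", "t"), and nested field "f":
    println(resolveParts(Seq("ns2", "t", "a", "f"), attrs))
    // -> Some((Attr(a,List(ns1, ns2, t)),List(f)))
  }
}

The suite added below exercises the same behavior against the real resolve, including the ambiguity errors this sketch leaves out.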
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala new file mode 100644 index 000000000000..7b6932781c03 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class AttributeResolutionSuite extends SparkFunSuite { + val resolver = caseInsensitiveResolution + + test("basic attribute resolution with namespaces") { + val attrs = Seq( + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2")), + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "ns3"))) + + // Try to match attribute reference with name "a" with qualifier "ns1.ns2". + Seq(Seq("ns2", "a"), Seq("ns1", "ns2", "a")).foreach { nameParts => + attrs.resolve(nameParts, resolver) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0))) + case _ => fail() + } + } + + // Resolution is ambiguous. + val ex = intercept[AnalysisException] { + attrs.resolve(Seq("a"), resolver) + } + assert(ex.getMessage.contains( + "Reference 'a' is ambiguous, could be: ns1.ns2.a, ns1.ns2.ns3.a.")) + + // Non-matching cases. 
+ Seq(Seq("ns1", "ns2"), Seq("ns1", "a")).foreach { nameParts => + val resolved = attrs.resolve(nameParts, resolver) + assert(resolved.isEmpty) + } + } + + test("attribute resolution with nested fields") { + val attrType = StructType(Seq(StructField("aa", IntegerType), StructField("bb", IntegerType))) + val attrs = Seq(AttributeReference("a", attrType)(qualifier = Seq("ns1", "ns2"))) + + val resolved = attrs.resolve(Seq("ns1", "ns2", "a", "aa"), resolver) + resolved match { + case Some(Alias(_, name)) => assert(name == "aa") + case _ => fail() + } + + val ex = intercept[AnalysisException] { + attrs.resolve(Seq("ns1", "ns2", "a", "cc"), resolver) + } + assert(ex.getMessage.contains("No such struct field cc in aa, bb")) + } + + test("attribute resolution with case insensitive resolver") { + val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2"))) + attrs.resolve(Seq("Ns1", "nS2", "A"), caseInsensitiveResolution) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0)) && attr.name == "A") + case _ => fail() + } + } + + test("attribute resolution with case sensitive resolver") { + val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2"))) + assert(attrs.resolve(Seq("Ns1", "nS2", "A"), caseSensitiveResolution).isEmpty) + assert(attrs.resolve(Seq("ns1", "ns2", "A"), caseSensitiveResolution).isEmpty) + attrs.resolve(Seq("ns1", "ns2", "a"), caseSensitiveResolution) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0))) + case _ => fail() + } + } + + test("attribute resolution should try to match the longest qualifier") { + // We have two attributes: + // 1) "a.b" where "a" is the name and "b" is the nested field. + // 2) "a.b.a" where "b" is the name, left-side "a" is the qualifier and the right-side "a" + // is the nested field. + // When "a.b" is resolved, "b" is tried first as the name, so it is resolved to #2 attribute. 
+ val a1Type = StructType(Seq(StructField("b", IntegerType))) + val a2Type = StructType(Seq(StructField("a", IntegerType))) + val attrs = Seq( + AttributeReference("a", a1Type)(), + AttributeReference("b", a2Type)(qualifier = Seq("a"))) + attrs.resolve(Seq("a", "b"), resolver) match { + case Some(attr) => assert(attr.semanticEquals(attrs(1))) + case _ => fail() + } + } +} From a507120730d0b7e9c0252bf6523634fc998e1dee Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 31 Jan 2020 10:39:50 -0800 Subject: [PATCH 06/11] Address PR comments --- .../sql/catalyst/analysis/Analyzer.scala | 29 ++++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../spark/sql/catalyst/identifiers.scala | 6 +- .../plans/logical/basicLogicalOperators.scala | 6 +- .../catalyst/plans/logical/v2Commands.scala | 2 +- .../sql/connector/catalog/CatalogV2Util.scala | 15 ++--- .../AttributeResolutionSuite.scala | 55 ++++++++++++------- .../sql/catalyst/trees/TreeNodeSuite.scala | 4 +- .../catalog/CatalogV2UtilSuite.scala | 13 ++--- .../command/PlanResolutionSuite.scala | 18 +++--- 10 files changed, 84 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4719dbccbc78..3761d84fabe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -798,6 +798,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp { case u: UnresolvedRelation => lookupV2Relation(u.multipartIdentifier) + .map(SubqueryAlias(u.multipartIdentifier, _)) .getOrElse(u) case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) => @@ -812,11 +813,7 @@ class Analyzer( case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) if i.query.resolved => lookupV2Relation(u.multipartIdentifier) - .map { - EliminateSubqueryAliases(_) match { - case r: DataSourceV2Relation => i.copy(table = r) - } - } + .map(v2Relation => i.copy(table = v2Relation)) .getOrElse(i) case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) => @@ -825,16 +822,22 @@ class Analyzer( .getOrElse(alter) case u: UnresolvedV2Relation => - CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u) + CatalogV2Util.loadRelation(u.catalog, u.tableName) + .map(SubqueryAlias(u.originalNameParts, _)) + .getOrElse(u) } /** * Performs the lookup of DataSourceV2 Tables from v2 catalog. 
*/ - private def lookupV2Relation(identifier: Seq[String]): Option[LogicalPlan] = + private def lookupV2Relation(identifier: Seq[String]): Option[DataSourceV2Relation] = expandRelationName(identifier) match { case NonSessionCatalogAndIdentifier(catalog, ident) => - CatalogV2Util.loadRelation(catalog, ident) + CatalogV2Util.loadTable(catalog, ident) match { + case Some(table) => + Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident))) + case None => None + } case _ => None } } @@ -885,7 +888,13 @@ class Analyzer( } case u: UnresolvedRelation => - lookupRelation(u.multipartIdentifier).map(resolveViews).getOrElse(u) + lookupRelation(u.multipartIdentifier) + .map { + case r: DataSourceV2Relation => SubqueryAlias(u.multipartIdentifier, r) + case other => other + } + .map(resolveViews) + .getOrElse(u) case u @ UnresolvedTable(identifier) => lookupTableOrView(identifier).map { @@ -922,7 +931,7 @@ class Analyzer( case v1Table: V1Table => v1SessionCatalog.getRelation(v1Table.v1Table) case table => - CatalogV2Util.getRelation(catalog, ident, table) + DataSourceV2Relation.create(table, Some(catalog), Some(ident)) } val key = catalog.name +: ident.namespace :+ ident.name Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index de510589781c..d6fc1dc6ddc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -425,8 +425,8 @@ trait CheckAnalysis extends PredicateHelper { case _ => } - case alter @ AlterTable(_, _, SubqueryAlias(_, table: NamedRelation), _) - if alter.childrenResolved => + case alter: AlterTable if alter.childrenResolved => + val table = alter.table def findField(operation: String, fieldName: Array[String]): StructField = { // include collections because structs nested in maps and arrays may be altered val field = table.schema.findNestedField(fieldName, includeCollections = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 6861ffe07fce..460a2db41a93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -52,14 +52,14 @@ sealed trait IdentifierWithDatabase { * name and a namespace. 
* The SubqueryAlias node keeps track of the qualifier using the information in this structure * @param name - Is an alias name or a table name - * @param namespace - Is a namespace + * @param qualifier - Is a qualifier */ -case class AliasIdentifier(name: String, namespace: Seq[String]) { +case class AliasIdentifier(name: String, qualifier: Seq[String]) { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ def this(identifier: String) = this(identifier, Seq()) - override def toString: String = (namespace :+ name).quoted + override def toString: String = (qualifier :+ name).quoted } object AliasIdentifier { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 222dd07797b4..54e5ff7aeb75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -861,7 +861,7 @@ case class SubqueryAlias( def alias: String = identifier.name override def output: Seq[Attribute] = { - val qualifierList = identifier.namespace :+ alias + val qualifierList = identifier.qualifier :+ alias child.output.map(_.withQualifier(qualifierList)) } override def doCanonicalize(): LogicalPlan = child.canonicalized @@ -882,9 +882,9 @@ object SubqueryAlias { } def apply( - identifier: Identifier, + multipartIdentifier: Seq[String], child: LogicalPlan): SubqueryAlias = { - SubqueryAlias(AliasIdentifier(identifier.name, identifier.namespace), child) + SubqueryAlias(AliasIdentifier(multipartIdentifier.last, multipartIdentifier.init), child) } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 289db7dc07df..c04e56355a68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -389,7 +389,7 @@ case class DropTable( case class AlterTable( catalog: TableCatalog, ident: Identifier, - table: LogicalPlan, + table: NamedRelation, changes: Seq[TableChange]) extends Command { override lazy val resolved: Boolean = table.resolved && { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index fee5bb5624df..0fabe4df6c9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -22,9 +22,8 @@ import java.util.Collections import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.AliasIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} -import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} +import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import 
org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
@@ -286,14 +285,8 @@ private[sql] object CatalogV2Util {
       case _: NoSuchNamespaceException => None
     }

-  def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[LogicalPlan] = {
-    loadTable(catalog, ident).map(getRelation(catalog, ident, _))
-  }
-
-  def getRelation(catalog: CatalogPlugin, ident: Identifier, table: Table): LogicalPlan = {
-    SubqueryAlias(
-      Identifier.of(catalog.name +: ident.namespace, ident.name),
-      DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
+  def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[NamedRelation] = {
+    loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident)))
   }

   def isSessionCatalog(catalog: CatalogPlugin): Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala
index 7b6932781c03..8ef0baa039e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala
@@ -27,60 +27,77 @@ class AttributeResolutionSuite extends SparkFunSuite {

   test("basic attribute resolution with namespaces") {
     val attrs = Seq(
-      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2")),
-      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "ns3")))
+      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t1")),
+      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "ns3", "t2")))

-    // Try to match attribute reference with name "a" with qualifier "ns1.ns2".
-    Seq(Seq("ns2", "a"), Seq("ns1", "ns2", "a")).foreach { nameParts =>
+    // Try to match an attribute reference named "a" whose qualifier is "ns1.ns2.t1".
+    Seq(Seq("t1", "a"), Seq("ns2", "t1", "a"), Seq("ns1", "ns2", "t1", "a")).foreach { nameParts =>
       attrs.resolve(nameParts, resolver) match {
         case Some(attr) => assert(attr.semanticEquals(attrs(0)))
         case _ => fail()
       }
     }

-    // Resolution is ambiguous.
+    // Non-matching cases.
+    Seq(Seq("ns1", "ns2", "t1"), Seq("ns2", "a")).foreach { nameParts =>
+      val resolved = attrs.resolve(nameParts, resolver)
+      assert(resolved.isEmpty)
+    }
+  }
+
+  test("attribute resolution ambiguity at the attribute name level") {
+    val attrs = Seq(
+      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t1")),
+      AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2", "t2")))
+
     val ex = intercept[AnalysisException] {
       attrs.resolve(Seq("a"), resolver)
     }
     assert(ex.getMessage.contains(
-      "Reference 'a' is ambiguous, could be: ns1.ns2.a, ns1.ns2.ns3.a."))
+      "Reference 'a' is ambiguous, could be: ns1.t1.a, ns1.ns2.t2.a."))
+  }

-    // Non-matching cases.
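
As an aside, the suffix rule exercised by the matching and non-matching cases above can be modeled in a few lines of standalone Scala. This is an illustrative sketch only, not Spark's implementation (`SimpleAttr` and `matches` are invented names), and it ignores case sensitivity, nested fields and ambiguity reporting: a reference matches an attribute when its last part equals the column name and the preceding parts form a suffix of the attribute's qualifier.

// Toy model of the qualifier matching rule, for illustration only.
case class SimpleAttr(name: String, qualifier: Seq[String])

def matches(ref: Seq[String], attr: SimpleAttr): Boolean =
  ref.nonEmpty && ref.last == attr.name && attr.qualifier.endsWith(ref.init)

val a = SimpleAttr("a", Seq("ns1", "ns2", "t1"))
assert(matches(Seq("t1", "a"), a))                // qualified by table name
assert(matches(Seq("ns2", "t1", "a"), a))         // partially qualified
assert(matches(Seq("ns1", "ns2", "t1", "a"), a))  // fully qualified
assert(!matches(Seq("ns2", "a"), a))              // "ns2" alone is not a qualifier suffix
assert(!matches(Seq("ns1", "ns2", "t1"), a))      // no column name part remains
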
- Seq(Seq("ns1", "ns2"), Seq("ns1", "a")).foreach { nameParts => - val resolved = attrs.resolve(nameParts, resolver) - assert(resolved.isEmpty) + test("attribute resolution ambiguity at the qualifier level") { + val attrs = Seq( + AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t")), + AttributeReference("a", IntegerType)(qualifier = Seq("ns2", "ns1", "t"))) + + val ex = intercept[AnalysisException] { + attrs.resolve(Seq("ns1", "t", "a"), resolver) } + assert(ex.getMessage.contains( + "Reference 'ns1.t.a' is ambiguous, could be: ns1.t.a, ns2.ns1.t.a.")) } test("attribute resolution with nested fields") { val attrType = StructType(Seq(StructField("aa", IntegerType), StructField("bb", IntegerType))) - val attrs = Seq(AttributeReference("a", attrType)(qualifier = Seq("ns1", "ns2"))) + val attrs = Seq(AttributeReference("a", attrType)(qualifier = Seq("ns1", "t"))) - val resolved = attrs.resolve(Seq("ns1", "ns2", "a", "aa"), resolver) + val resolved = attrs.resolve(Seq("ns1", "t", "a", "aa"), resolver) resolved match { case Some(Alias(_, name)) => assert(name == "aa") case _ => fail() } val ex = intercept[AnalysisException] { - attrs.resolve(Seq("ns1", "ns2", "a", "cc"), resolver) + attrs.resolve(Seq("ns1", "t", "a", "cc"), resolver) } assert(ex.getMessage.contains("No such struct field cc in aa, bb")) } test("attribute resolution with case insensitive resolver") { - val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2"))) - attrs.resolve(Seq("Ns1", "nS2", "A"), caseInsensitiveResolution) match { + val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t"))) + attrs.resolve(Seq("Ns1", "T", "A"), caseInsensitiveResolution) match { case Some(attr) => assert(attr.semanticEquals(attrs(0)) && attr.name == "A") case _ => fail() } } test("attribute resolution with case sensitive resolver") { - val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "ns2"))) - assert(attrs.resolve(Seq("Ns1", "nS2", "A"), caseSensitiveResolution).isEmpty) - assert(attrs.resolve(Seq("ns1", "ns2", "A"), caseSensitiveResolution).isEmpty) - attrs.resolve(Seq("ns1", "ns2", "a"), caseSensitiveResolution) match { + val attrs = Seq(AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t"))) + assert(attrs.resolve(Seq("Ns1", "T", "A"), caseSensitiveResolution).isEmpty) + assert(attrs.resolve(Seq("ns1", "t", "A"), caseSensitiveResolution).isEmpty) + attrs.resolve(Seq("ns1", "t", "a"), caseSensitiveResolution) match { case Some(attr) => assert(attr.semanticEquals(attrs(0))) case _ => fail() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 0795f378eaf9..e72b2e9b1b21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -437,7 +437,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { JObject( "product-class" -> JString(classOf[AliasIdentifier].getName), "name" -> "alias", - "namespace" -> "[ns1, ns2]")) + "qualifier" -> "[ns1, ns2]")) // Converts SubqueryAlias to JSON assertJSON( @@ -448,7 +448,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { "num-children" -> 1, "identifier" -> JObject("product-class" -> JString(classOf[AliasIdentifier].getName), "name" -> "t1", - "namespace" -> JArray(Nil)), + "qualifier" -> JArray(Nil)), "child" -> 0), 
JObject( "class" -> classOf[JsonTestTreeNode].getName, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala index fbfcc567cd2f..7a9a7f52ff8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala @@ -20,24 +20,21 @@ package org.apache.spark.sql.connector.catalog import org.mockito.Mockito.{mock, when} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.StructType class CatalogV2UtilSuite extends SparkFunSuite { test("Load relation should encode the identifiers for V2Relations") { val testCatalog = mock(classOf[TableCatalog]) - val ident = Identifier.of(Array("ns1", "ns2"), "tbl") + val ident = mock(classOf[Identifier]) val table = mock(classOf[Table]) when(table.schema()).thenReturn(mock(classOf[StructType])) when(testCatalog.loadTable(ident)).thenReturn(table) val r = CatalogV2Util.loadRelation(testCatalog, ident) assert(r.isDefined) - r.get match { - case SubqueryAlias(_, v2Relation: DataSourceV2Relation) => - assert(v2Relation.catalog.exists(_ == testCatalog)) - assert(v2Relation.identifier.exists(_ == ident)) - case _ => fail() - } + assert(r.get.isInstanceOf[DataSourceV2Relation]) + val v2Relation = r.get.asInstanceOf[DataSourceV2Relation] + assert(v2Relation.catalog.exists(_ == testCatalog)) + assert(v2Relation.identifier.exists(_ == ident)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 0aff678e7944..cb6922eb3ca3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -726,7 +726,7 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed3, expected3) } else { parsed1 match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.setProperty("test", "test"), TableChange.setProperty("comment", "new_comment"))) @@ -734,7 +734,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed2 match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.removeProperty("comment"), TableChange.removeProperty("test"))) @@ -742,7 +742,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed3 match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.removeProperty("comment"), TableChange.removeProperty("test"))) @@ -785,7 +785,7 @@ class PlanResolutionSuite extends AnalysisTest { comparePlans(parsed, expected) } else { parsed match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.setProperty("a", "1"), TableChange.setProperty("b", "0.1"), @@ -809,7 +809,7 @@ class PlanResolutionSuite extends AnalysisTest { 
comparePlans(parsed, expected) } else { parsed match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq(TableChange.setProperty("location", "new location"))) case _ => fail("Expect AlterTable, but got:\n" + parsed.treeString) } @@ -1058,14 +1058,14 @@ class PlanResolutionSuite extends AnalysisTest { val parsed3 = parseAndResolve(sql3) parsed1 match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.updateColumnType(Array("i"), LongType))) case _ => fail("expect AlterTable") } parsed2 match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.updateColumnType(Array("i"), LongType), TableChange.updateColumnComment(Array("i"), "new comment"))) @@ -1073,7 +1073,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed3 match { - case AlterTable(_, _, AsDataSourceV2Relation(_), changes) => + case AlterTable(_, _, _: DataSourceV2Relation, changes) => assert(changes == Seq( TableChange.updateColumnComment(Array("i"), "new comment"))) case _ => fail("expect AlterTable") @@ -1109,7 +1109,7 @@ class PlanResolutionSuite extends AnalysisTest { val catlogIdent = if (isSessionCatlog) v2SessionCatalog else testCat val tableIdent = if (isSessionCatlog) "v2Table" else "tab" parsed match { - case AlterTable(_, _, AsDataSourceV2Relation(r), _) => + case AlterTable(_, _, r: DataSourceV2Relation, _) => assert(r.catalog.exists(_ == catlogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) case Project(_, AsDataSourceV2Relation(r)) => From 6fb5799958171f3a1dfcb64adc476c58aaf6d124 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 31 Jan 2020 11:15:15 -0800 Subject: [PATCH 07/11] refinement to minimize diff --- .../sql/execution/command/PlanResolutionSuite.scala | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index f2ffcc39ecfa..88f30353cce9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -957,9 +957,7 @@ class PlanResolutionSuite extends AnalysisTest { parsed2 match { case UpdateTable( - SubqueryAlias( - AliasIdentifier("t", Seq()), - AsDataSourceV2Relation(_)), + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), None) => @@ -971,9 +969,7 @@ class PlanResolutionSuite extends AnalysisTest { parsed3 match { case UpdateTable( - SubqueryAlias( - AliasIdentifier("t", Seq()), - AsDataSourceV2Relation(_)), + SubqueryAlias(AliasIdentifier("t", Seq()), AsDataSourceV2Relation(_)), Seq(Assignment(name: UnresolvedAttribute, StringLiteral("Robert")), Assignment(age: UnresolvedAttribute, IntegerLiteral(32))), Some(EqualTo(p: UnresolvedAttribute, IntegerLiteral(1)))) => @@ -985,9 +981,7 @@ class PlanResolutionSuite extends AnalysisTest { } parsed4 match { - case UpdateTable( - SubqueryAlias(AliasIdentifier("t", Seq()), - AsDataSourceV2Relation(_)), + case UpdateTable(SubqueryAlias(AliasIdentifier("t", 
Seq()), AsDataSourceV2Relation(_)), Seq(Assignment(key: UnresolvedAttribute, IntegerLiteral(32))), Some(InSubquery(values, query))) => assert(key.name == "t.age") From b69e91a12f0ec5ea100cf665a7f6648aa610b8ae Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 3 Feb 2020 10:32:33 -0800 Subject: [PATCH 08/11] address PR comments --- .../sql/catalyst/analysis/Analyzer.scala | 16 +-- .../sql/catalyst/expressions/package.scala | 118 ++++++++++++++++-- 2 files changed, 115 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9015daee299b..56cc2a274bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -823,9 +823,7 @@ class Analyzer( .getOrElse(alter) case u: UnresolvedV2Relation => - CatalogV2Util.loadRelation(u.catalog, u.tableName) - .map(SubqueryAlias(u.originalNameParts, _)) - .getOrElse(u) + CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u) } /** @@ -889,13 +887,7 @@ class Analyzer( } case u: UnresolvedRelation => - lookupRelation(u.multipartIdentifier) - .map { - case r: DataSourceV2Relation => SubqueryAlias(u.multipartIdentifier, r) - case other => other - } - .map(resolveViews) - .getOrElse(u) + lookupRelation(u.multipartIdentifier).map(resolveViews).getOrElse(u) case u @ UnresolvedTable(identifier) => lookupTableOrView(identifier).map { @@ -932,7 +924,9 @@ class Analyzer( case v1Table: V1Table => v1SessionCatalog.getRelation(v1Table.v1Table) case table => - DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + SubqueryAlias( + identifier, + DataSourceV2Relation.create(table, Some(catalog), Some(ident))) } val key = catalog.name +: ident.namespace :+ ident.name Option(AnalysisContext.get.relationCache.getOrElseUpdate(key, loaded.orNull)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 10c2aa77f937..247136bfe063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -127,13 +127,102 @@ package object expressions { m.mapValues(_.distinct).map(identity) } - /** Attribute name to attributes */ - @transient private val attrsMap: Map[String, Seq[Attribute]] = { + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) } - /** Perform attribute resolution given a name and a resolver. 
 */
-  def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
+  /** Map to use for qualified case insensitive attribute lookups with a 2-part key */
+  @transient private lazy val qualified: Map[(String, String), Seq[Attribute]] = {
+    // key is 2 parts: table/alias name and column name
+    val grouped = attrs.filter(_.qualifier.nonEmpty).groupBy {
+      a => (a.qualifier.last.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT))
+    }
+    unique(grouped)
+  }
+
+  /** Map to use for qualified case insensitive attribute lookups with a 3-part key */
+  @transient private val qualified3Part: Map[(String, String, String), Seq[Attribute]] = {
+    // key is 3 parts: database name, table name and column name
+    val grouped = attrs.filter(_.qualifier.length == 2).groupBy { a =>
+      (a.qualifier.head.toLowerCase(Locale.ROOT),
+        a.qualifier.last.toLowerCase(Locale.ROOT),
+        a.name.toLowerCase(Locale.ROOT))
+    }
+    unique(grouped)
+  }
+
+  /** Returns true if all qualifiers in `attrs` have 2 or less parts. */
+  @transient private val has2OrLessPartQualifiers: Boolean = attrs.forall(_.qualifier.length <= 2)
+
+  /** Match attributes for the case where all qualifiers in `attrs` have 2 or less parts. */
+  private def matchWith2OrLessPartQualifiers(
+      nameParts: Seq[String],
+      resolver: Resolver): (Seq[Attribute], Seq[String]) = {
+    // Collect matching attributes given a name and a lookup.
+    def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
+      candidates.toSeq.flatMap(_.collect {
+        case a if resolver(a.name, name) => a.withName(name)
+      })
+    }
+
+    // Find matches for the given name assuming that the first two parts are the qualifier
+    // (i.e. database name and table name) and the third part is the actual column name.
+    //
+    // For example, consider a case where "db1" is the database name, "a" is the table name,
+    // "b" is the column name and "c" is the struct field name.
+    // If the name parts are db1.a.b.c, the matched attribute will be
+    // Attribute(b, qualifier("db1", "a")) and List("c") will be the second element.
+    var matches: (Seq[Attribute], Seq[String]) = nameParts match {
+      case dbPart +: tblPart +: name +: nestedFields =>
+        val key = (dbPart.toLowerCase(Locale.ROOT),
+          tblPart.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT))
+        val attributes = collectMatches(name, qualified3Part.get(key)).filter {
+          a => (resolver(dbPart, a.qualifier.head) && resolver(tblPart, a.qualifier.last))
+        }
+        (attributes, nestedFields)
+      case _ =>
+        (Seq.empty, Seq.empty)
+    }
+
+    // If there are no matches, then find matches for the given name assuming that
+    // the first part is a qualifier (i.e. table name, alias, or subquery alias) and the
+    // second part is the actual column name. This returns a tuple of
+    // matched attributes and a list of parts that are yet to be resolved.
+    //
+    // For example, consider a case where "a" is the table name, "b" is the column name,
+    // and "c" is the struct field name, i.e. "a.b.c". In this case, the matched attribute
+    // will be "a.b", and the second element will be List("c").
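+    //
+    // Note that `matches._1` is non-empty here only if the 3-part lookup above
+    // succeeded, so when a reference is resolvable both ways, the longer, more
+    // specific (database, table, column) interpretation takes precedence.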
+ if (matches._1.isEmpty) { + matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.last) + } + (attributes, nestedFields) + case _ => + (Seq.empty[Attribute], Seq.empty[String]) + } + } + + // If none of attributes match database.table.column pattern or + // `table.column` pattern, we try to resolve it as a column. + matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + } + + /** + * Match attributes for the case where at least one qualifier in `attrs` has more than 2 parts. + */ + private def matchWith3OrMorePartQualifiers( + nameParts: Seq[String], + resolver: Resolver): (Seq[Attribute], Seq[String]) = { // Returns true if the `short` qualifier is a subset of the last elements of // `long` qualifier. For example, Seq("a", "b") is a subset of Seq("a", "a", "b"), // but not a subset of Seq("a", "b", "b"). @@ -171,7 +260,7 @@ package object expressions { val matched = collectMatches( name, nameParts.take(index), - attrsMap.get(name.toLowerCase(Locale.ROOT))) + direct.get(name.toLowerCase(Locale.ROOT))) if (matched.nonEmpty) { (matched, nameParts.takeRight(nameParts.length - index - 1)) :: Nil } else { @@ -179,12 +268,25 @@ package object expressions { } } - if (matches.isEmpty) { - return None + if (!matches.hasNext) { + return (Nil, Nil) } // Note that `matches` is an iterator, and only the first match will be used. - val (candidates, nestedFields) = matches.next + if (matches.hasNext) { + matches.next + } else { + (Nil, Nil) + } + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + val (candidates, nestedFields) = if (has2OrLessPartQualifiers) { + matchWith2OrLessPartQualifiers(nameParts, resolver) + } else { + matchWith3OrMorePartQualifiers(nameParts, resolver) + } def name = UnresolvedAttribute(nameParts).name candidates match { From 15c7003d7e5eb1a41b9bff34aa2c120d03ff6b95 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 3 Feb 2020 10:37:08 -0800 Subject: [PATCH 09/11] rename functions --- .../spark/sql/catalyst/expressions/package.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 247136bfe063..1cdd6c2d0b77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -153,10 +153,11 @@ package object expressions { } /** Returns true if all qualifiers in `attrs` have 2 or less parts. */ - @transient private val has2OrLessPartQualifiers: Boolean = attrs.forall(_.qualifier.length <= 2) + @transient private val hasTwoOrLessPartQualifiers: Boolean = + attrs.forall(_.qualifier.length <= 2) /** Match attributes for the case where all qualifiers in `attrs` have 2 or less parts. 
*/ - private def matchWith2OrLessPartQualifiers( + private def matchWithTwoOrLessPartQualifiers( nameParts: Seq[String], resolver: Resolver): (Seq[Attribute], Seq[String]) = { // Collect matching attributes given a name and a lookup. @@ -220,7 +221,7 @@ package object expressions { /** * Match attributes for the case where at least one qualifier in `attrs` has more than 2 parts. */ - private def matchWith3OrMorePartQualifiers( + private def matchWithThreeOrMorePartQualifiers( nameParts: Seq[String], resolver: Resolver): (Seq[Attribute], Seq[String]) = { // Returns true if the `short` qualifier is a subset of the last elements of @@ -282,10 +283,10 @@ package object expressions { /** Perform attribute resolution given a name and a resolver. */ def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { - val (candidates, nestedFields) = if (has2OrLessPartQualifiers) { - matchWith2OrLessPartQualifiers(nameParts, resolver) + val (candidates, nestedFields) = if (hasTwoOrLessPartQualifiers) { + matchWithTwoOrLessPartQualifiers(nameParts, resolver) } else { - matchWith3OrMorePartQualifiers(nameParts, resolver) + matchWithThreeOrMorePartQualifiers(nameParts, resolver) } def name = UnresolvedAttribute(nameParts).name From 24daf5f8d38cdb68ba940f243585be9241652cd9 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 4 Feb 2020 14:05:06 -0800 Subject: [PATCH 10/11] address PR comments --- .../sql/catalyst/expressions/package.scala | 36 ++++++++----------- .../spark/sql/catalyst/identifiers.scala | 2 +- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1cdd6c2d0b77..2e8af389e48a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -162,9 +162,9 @@ package object expressions { resolver: Resolver): (Seq[Attribute], Seq[String]) = { // Collect matching attributes given a name and a lookup. def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { - candidates.toSeq.flatMap(_.collect { + candidates.getOrElse(Nil).collect { case a if resolver(a.name, name) => a.withName(name) - }) + } } // Find matches for the given name assuming that the 1st two parts are qualifier @@ -242,10 +242,10 @@ package object expressions { name: String, qualifier: Seq[String], candidates: Option[Seq[Attribute]]): Seq[Attribute] = { - candidates.toSeq.flatMap(_.collect { + candidates.getOrElse(Nil).collect { case a if resolver(name, a.name) && matchQualifier(qualifier, a.qualifier) => a.withName(name) - }) + } } // Iterate each string in `nameParts` in a reverse order and try to match the attributes @@ -257,28 +257,22 @@ package object expressions { // Note that the match is performed in the reverse order in order to match the longest // qualifier as possible. If a match is found, the remaining portion of `nameParts` // is also returned as nested fields. 
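
The reverse scan described in the comment above can likewise be sketched as standalone Scala (again a toy model with an invented `SimpleAttr` type; the real code also routes name comparison through the resolver and reports ambiguity). The split point moves from right to left, so the longest possible qualifier is tried first, and anything after the matched column name is left over as nested fields.

case class SimpleAttr(name: String, qualifier: Seq[String])

// Try split points from the right: parts(i) is the candidate column name,
// parts.take(i) the qualifier and parts.drop(i + 1) the nested fields.
def resolveToy(
    parts: Seq[String],
    attrs: Seq[SimpleAttr]): Option[(SimpleAttr, Seq[String])] = {
  val hits = (parts.length - 1 to 0 by -1).iterator.flatMap { i =>
    attrs.collect {
      case a if a.name == parts(i) && a.qualifier.endsWith(parts.take(i)) =>
        (a, parts.drop(i + 1))
    }
  }
  if (hits.hasNext) Some(hits.next()) else None
}

val attr = SimpleAttr("a", Seq("ns1", "t"))
// A fully qualified column reference leaves no nested fields:
assert(resolveToy(Seq("ns1", "t", "a"), Seq(attr)) == Some((attr, Seq())))
// A shorter reference leaves the trailing "aa" as a nested field:
assert(resolveToy(Seq("t", "a", "aa"), Seq(attr)) == Some((attr, Seq("aa"))))
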
-    val matches = nameParts.zipWithIndex.reverseIterator.flatMap { case (name, index) =>
-      val matched = collectMatches(
+    var candidates: Seq[Attribute] = Nil
+    var nestedFields: Seq[String] = Nil
+    var i = nameParts.length - 1
+    while (i >= 0 && candidates.isEmpty) {
+      val name = nameParts(i)
+      candidates = collectMatches(
         name,
-        nameParts.take(index),
+        nameParts.take(i),
         direct.get(name.toLowerCase(Locale.ROOT)))
-      if (matched.nonEmpty) {
-        (matched, nameParts.takeRight(nameParts.length - index - 1)) :: Nil
-      } else {
-        Nil
+      if (candidates.nonEmpty) {
+        nestedFields = nameParts.takeRight(nameParts.length - i - 1)
       }
+      i -= 1
     }

-    if (!matches.hasNext) {
-      return (Nil, Nil)
-    }
-
-    // Note that `matches` is an iterator, and only the first match will be used.
-    if (matches.hasNext) {
-      matches.next
-    } else {
-      (Nil, Nil)
-    }
+
+    (candidates, nestedFields)
   }

   /** Perform attribute resolution given a name and a resolver. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index 460a2db41a93..c574a20da0b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -49,7 +49,7 @@ sealed trait IdentifierWithDatabase {

 /**
  * Encapsulates an identifier that is either an alias name or an identifier that has table
- * name and a namespace.
+ * name and a qualifier.
  * The SubqueryAlias node keeps track of the qualifier using the information in this structure
  * @param name - Is an alias name or a table name
  * @param qualifier - Is a qualifier

From c40895ff5308e3cbd9b35880a988ef9ef61ba578 Mon Sep 17 00:00:00 2001
From: Terry Kim
Date: Tue, 4 Feb 2020 20:24:36 -0800
Subject: [PATCH 11/11] address PR comments

---
 .../sql/catalyst/expressions/package.scala    | 12 +++++------
 .../AttributeResolutionSuite.scala            | 21 ++++++++++++++++---
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 2e8af389e48a..9f42e643e4cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -153,11 +153,11 @@ package object expressions {
   }

   /** Returns true if all qualifiers in `attrs` have 2 or less parts. */
-  @transient private val hasTwoOrLessPartQualifiers: Boolean =
+  @transient private val hasTwoOrLessQualifierParts: Boolean =
     attrs.forall(_.qualifier.length <= 2)

   /** Match attributes for the case where all qualifiers in `attrs` have 2 or less parts. */
-  private def matchWithTwoOrLessPartQualifiers(
+  private def matchWithTwoOrLessQualifierParts(
       nameParts: Seq[String],
       resolver: Resolver): (Seq[Attribute], Seq[String]) = {
     // Collect matching attributes given a name and a lookup.
     def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
@@ -221,7 +221,7 @@
   /**
    * Match attributes for the case where at least one qualifier in `attrs` has more than 2 parts.
*/ - private def matchWithThreeOrMorePartQualifiers( + private def matchWithThreeOrMoreQualifierParts( nameParts: Seq[String], resolver: Resolver): (Seq[Attribute], Seq[String]) = { // Returns true if the `short` qualifier is a subset of the last elements of @@ -277,10 +277,10 @@ package object expressions { /** Perform attribute resolution given a name and a resolver. */ def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { - val (candidates, nestedFields) = if (hasTwoOrLessPartQualifiers) { - matchWithTwoOrLessPartQualifiers(nameParts, resolver) + val (candidates, nestedFields) = if (hasTwoOrLessQualifierParts) { + matchWithTwoOrLessQualifierParts(nameParts, resolver) } else { - matchWithThreeOrMorePartQualifiers(nameParts, resolver) + matchWithThreeOrMoreQualifierParts(nameParts, resolver) } def name = UnresolvedAttribute(nameParts).name diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala index 8ef0baa039e1..813a68f68451 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeResolutionSuite.scala @@ -38,13 +38,28 @@ class AttributeResolutionSuite extends SparkFunSuite { } } - // Non-matching cases. + // Non-matching cases Seq(Seq("ns1", "ns2", "t1"), Seq("ns2", "a")).foreach { nameParts => - val resolved = attrs.resolve(nameParts, resolver) - assert(resolved.isEmpty) + assert(attrs.resolve(nameParts, resolver).isEmpty) } } + test("attribute resolution where table and attribute names are the same") { + val attrs = Seq(AttributeReference("t", IntegerType)(qualifier = Seq("ns1", "ns2", "t"))) + // Matching cases + Seq( + Seq("t"), Seq("t", "t"), Seq("ns2", "t", "t"), Seq("ns1", "ns2", "t", "t") + ).foreach { nameParts => + attrs.resolve(nameParts, resolver) match { + case Some(attr) => assert(attr.semanticEquals(attrs(0))) + case _ => fail() + } + } + + // Non-matching case + assert(attrs.resolve(Seq("ns1", "ns2", "t"), resolver).isEmpty) + } + test("attribute resolution ambiguity at the attribute name level") { val attrs = Seq( AttributeReference("a", IntegerType)(qualifier = Seq("ns1", "t1")),