diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba6764444bdf..95e2ddd40af1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1796,6 +1796,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case s: Sort if !s.resolved || s.missingInput.nonEmpty => resolveReferencesInSort(s) + case u: UnresolvedWithCTERelations => + UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), u.cteRelations) + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 2982d8477fcc..ff0dbcd7ef15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Command, CTEInChildren, CTERelationDef, CTERelationRef, InsertIntoDir, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.TypeUtils._ +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import 
org.apache.spark.sql.internal.SQLConf.LEGACY_CTE_PRECEDENCE_POLICY @@ -272,7 +272,8 @@ object CTESubstitution extends Rule[LogicalPlan] { alwaysInline: Boolean, cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = { plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION)) { + _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION, + UNRESOLVED_IDENTIFIER)) { case RelationTimeTravel(UnresolvedRelation(Seq(table), _, _), _, _) if cteRelations.exists(r => plan.conf.resolver(r._1, table)) => throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table)) @@ -287,6 +288,14 @@ object CTESubstitution extends Rule[LogicalPlan] { } }.getOrElse(u) + case p: PlanWithUnresolvedIdentifier => + // We must look up CTE relations first when resolving `UnresolvedRelation`s, + // but we can't do it here as `PlanWithUnresolvedIdentifier` is a leaf node + // and may produce `UnresolvedRelation` later. + // Here we wrap it with `UnresolvedWithCTERelations` so that we can + // delay the CTE relations lookup after `PlanWithUnresolvedIdentifier` is resolved. + UnresolvedWithCTERelations(p, cteRelations) + case other => // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. 
other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala index f04b7799e35e..e0142c445ae8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER +import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE} import org.apache.spark.sql.types.StringType /** @@ -35,9 +35,18 @@ class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch] } override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { + _.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) { case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved => executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr))) + case u @ UnresolvedWithCTERelations(p, cteRelations) => + this.apply(p) match { + case u @ UnresolvedRelation(Seq(table), _, _) => + cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) => + // Add a `SubqueryAlias` for hint-resolving rules to match relation names. 
+ SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming)) + }.getOrElse(u) + case other => other + } case other => other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a2cab60b392b..abb7e7956f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} +import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -65,6 +65,17 @@ case class PlanWithUnresolvedIdentifier( final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER) } +/** + * A logical plan placeholder which delays CTE resolution + * to the moment when [[PlanWithUnresolvedIdentifier]] gets resolved. + */ +case class UnresolvedWithCTERelations( + unresolvedPlan: LogicalPlan, + cteRelations: Seq[(String, CTERelationDef)]) + extends UnresolvedLeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER_WITH_CTE) +} + /** * An expression placeholder that holds the identifier clause string expression. 
It will be * replaced by the actual expression with the evaluated identifier string. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index c5cc1eaf8f05..6258bd615b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -151,6 +151,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value val UNRESOLVED_WINDOW_EXPRESSION: Value = Value + val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_FUNC: Value = Value diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out index b3e2cd5ada95..f0bf8b883dd8 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out @@ -985,6 +985,79 @@ DropTable false, false +- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2 +-- !query +DECLARE agg = 'max' +-- !query analysis +CreateVariable defaultvalueexpression(max, 'max'), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.agg + + +-- !query +DECLARE col = 'c1' +-- !query analysis +CreateVariable defaultvalueexpression(c1, 'c1'), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.col + + +-- !query +DECLARE tab = 'T' +-- !query analysis +CreateVariable defaultvalueexpression(T, 'T'), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.tab + + +-- !query +WITH S(c1, c2) AS (VALUES(1, 2), (2, 
3)), + T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd')) +SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab) +-- !query analysis +WithCTE +:- CTERelationDef xxxx, false +: +- SubqueryAlias S +: +- Project [col1#x AS c1#x, col2#x AS c2#x] +: +- LocalRelation [col1#x, col2#x] +:- CTERelationDef xxxx, false +: +- SubqueryAlias T +: +- Project [col1#x AS c1#x, col2#x AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- Aggregate [max(c1#x) AS max(c1)#x] + +- SubqueryAlias T + +- CTERelationRef xxxx, true, [c1#x, c2#x], false + + +-- !query +WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)), + T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd')) +SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T') +-- !query analysis +WithCTE +:- CTERelationDef xxxx, false +: +- SubqueryAlias S +: +- Project [col1#x AS c1#x, col2#x AS c2#x] +: +- LocalRelation [col1#x, col2#x] +:- CTERelationDef xxxx, false +: +- SubqueryAlias T +: +- Project [col1#x AS c1#x, col2#x AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- Aggregate [max(c1#x) AS max(c1)#x] + +- SubqueryAlias T + +- CTERelationRef xxxx, true, [c1#x, c2#x], false + + +-- !query +WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3)) +SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC') +-- !query analysis +WithCTE +:- CTERelationDef xxxx, false +: +- SubqueryAlias ABC +: +- Project [col1#x AS c1#x, col2#x AS c2#x] +: +- LocalRelation [col1#x, col2#x] ++- Aggregate [max(c1#x) AS max(c1)#x] + +- SubqueryAlias ABC + +- CTERelationRef xxxx, true, [c1#x, c2#x], false + + -- !query SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql index 46461dcd048e..4aa8019097fd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql @@ 
-141,6 +141,22 @@ drop view v1; drop table t1; drop table t2; +-- SPARK-46625: CTE reference with identifier clause and session variables +DECLARE agg = 'max'; +DECLARE col = 'c1'; +DECLARE tab = 'T'; + +WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)), + T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd')) +SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab); + +WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)), + T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd')) +SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T'); + +WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3)) +SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC'); + -- Not supported SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1); SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1')); diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out index 2aa809324a76..952fb8fdc2bd 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out @@ -1115,6 +1115,59 @@ struct<> +-- !query +DECLARE agg = 'max' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE col = 'c1' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE tab = 'T' +-- !query schema +struct<> +-- !query output + + + +-- !query +WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)), + T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd')) +SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab) +-- !query schema +struct<max(c1):string> +-- !query output +c + + +-- !query +WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)), + T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd')) +SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T') +-- !query schema +struct<max(c1):string> +-- !query output +c + + +-- !query +WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3)) +SELECT 
IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC') +-- !query schema +struct<max(c1):int> +-- !query output +2 + + -- !query SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1) -- !query schema