sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1796,6 +1796,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      case s: Sort if !s.resolved || s.missingInput.nonEmpty =>
        resolveReferencesInSort(s)

+      case u: UnresolvedWithCTERelations =>
Contributor:

Thinking about it more, we now need to special-case UnresolvedWithCTERelations twice: once in ResolveReferences to resolve session variables, and once in ResolveIdentifierClause to resolve the identifier and look up CTE relations.

How about we make UnresolvedWithCTERelations a unary node, and only special-case it once in ResolveRelations, so that we look up CTE relations for UnresolvedRelations inside UnresolvedWithCTERelations? Sorry for the back and forth!

Contributor Author (@nebojsa-db, Jul 8, 2024):

No worries! Hm, the issue with that approach is that ResolveRelations traverses the tree bottom-up, so we would do the table lookup before the CTE relations lookup, since the rule encounters UnresolvedRelation before it reaches UnresolvedWithCTERelations? (See the sketch after this diff hunk.)

Contributor:

You are right, it's better to keep the bottom-up resolution.

+        UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), u.cteRelations)

      case q: LogicalPlan =>
        logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
        q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
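
A minimal, self-contained Scala sketch of the traversal problem discussed in the thread above. The node types (CteScope, Relation) and the resolveUp helper are illustrative stand-ins, not Catalyst classes; only the children-first visit order mirrors resolveOperatorsUp:

sealed trait Node
case class Relation(name: String) extends Node
case class CteScope(child: Node, ctes: Set[String]) extends Node

// Apply `rule` to children first, then to the node itself (bottom-up).
def resolveUp(n: Node)(rule: PartialFunction[Node, Node]): Node = {
  val recursed = n match {
    case CteScope(child, ctes) => CteScope(resolveUp(child)(rule), ctes)
    case leaf => leaf
  }
  rule.applyOrElse(recursed, identity[Node])
}

// A table-lookup rule fires on the leaf Relation("T") before CteScope is
// visited, so the CTE named T in the enclosing scope never gets a chance to win.
val plan = CteScope(Relation("T"), ctes = Set("T"))
val resolved = resolveUp(plan) {
  case Relation(name) => Relation(s"catalog.$name")
}
// resolved == CteScope(Relation("catalog.T"), Set("T"))

Hence the approach taken in this PR: keep the bottom-up traversal, but recurse into the wrapped plan from inside the wrapper, where the CTE definitions are still at hand.
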
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Command, CTEInChildren, CTERelationDef, CTERelationRef, InsertIntoDir, LogicalPlan, ParsedStatement, SubqueryAlias, UnresolvedWith, WithCTE}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.TypeUtils._
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.SQLConf.LEGACY_CTE_PRECEDENCE_POLICY
@@ -272,7 +272,8 @@ object CTESubstitution extends Rule[LogicalPlan] {
      alwaysInline: Boolean,
      cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = {
    plan.resolveOperatorsUpWithPruning(
-      _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION)) {
+      _.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION,
+        UNRESOLVED_IDENTIFIER)) {
      case RelationTimeTravel(UnresolvedRelation(Seq(table), _, _), _, _)
          if cteRelations.exists(r => plan.conf.resolver(r._1, table)) =>
        throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table))
@@ -287,6 +288,14 @@
          }
        }.getOrElse(u)

+      case p: PlanWithUnresolvedIdentifier =>
+        // We must look up CTE relations first when resolving `UnresolvedRelation`s,
+        // but we can't do that here, as `PlanWithUnresolvedIdentifier` is a leaf node
+        // and may produce an `UnresolvedRelation` only later.
+        // So we wrap it with `UnresolvedWithCTERelations` to delay the CTE relations
+        // lookup until after `PlanWithUnresolvedIdentifier` is resolved.
+        UnresolvedWithCTERelations(p, cteRelations)
+
      case other =>
        // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
        other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
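
A minimal, self-contained sketch of the wrap-and-delay pattern described in the comment above. All names here (IdentPlaceholder, WithInScopeCtes, and the two helper functions) are illustrative stand-ins for PlanWithUnresolvedIdentifier, UnresolvedWithCTERelations, and the two analyzer rules, not Spark APIs:

sealed trait Plan
case class IdentPlaceholder(eval: () => String) extends Plan // leaf; expands later
case class WithInScopeCtes(child: Plan, ctes: Seq[String]) extends Plan
case class CatalogRelation(name: String) extends Plan
case class CteRef(name: String) extends Plan

// CTE substitution time: the identifier cannot be evaluated yet, so only
// record the CTE names that are in scope next to the placeholder.
def substitute(p: Plan, ctes: Seq[String]): Plan = p match {
  case ip: IdentPlaceholder => WithInScopeCtes(ip, ctes)
  case other => other
}

// Identifier resolution time: expand the placeholder, then give the
// remembered CTEs precedence over a catalog lookup.
def resolveIdentifiers(p: Plan): Plan = p match {
  case WithInScopeCtes(IdentPlaceholder(eval), ctes) =>
    val name = eval()
    if (ctes.exists(_.equalsIgnoreCase(name))) CteRef(name) else CatalogRelation(name)
  case other => other
}

// resolveIdentifiers(substitute(IdentPlaceholder(() => "T"), Seq("S", "T")))
// returns CteRef("T") rather than CatalogRelation("T").
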
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
-import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
+import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE}
import org.apache.spark.sql.types.StringType

/**
@@ -35,9 +35,18 @@ class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch]
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
-    _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
+    _.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) {
    case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved =>
      executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)))
+    case u @ UnresolvedWithCTERelations(p, cteRelations) =>
+      this.apply(p) match {
+        case u @ UnresolvedRelation(Seq(table), _, _) =>
+          cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) =>
+            // Add a `SubqueryAlias` for hint-resolving rules to match relation names.
+            SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
+          }.getOrElse(u)
+        case other => other
+      }
    case other =>
      other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
        case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
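
A hedged sketch of the name-matching step in the new case above, with stub data in place of real CTERelationDefs. plan.conf.resolver compares names case-insensitively under the default spark.sql.caseSensitive=false, and because only single-part names (Seq(table)) are matched, qualified names fall through to normal catalog resolution:

// Stand-in for plan.conf.resolver with spark.sql.caseSensitive=false.
val resolver: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)

// Hypothetical stand-in data: CTE name -> definition id.
val cteRelations = Seq("S" -> 1L, "T" -> 2L)

def lookupCte(nameParts: Seq[String]): Option[Long] = nameParts match {
  case Seq(table) =>
    cteRelations.find { case (name, _) => resolver(name, table) }.map(_._2)
  case _ => None // qualified names such as db.T are never CTE references
}

assert(lookupCte(Seq("t")).contains(2L))  // case-insensitive match on T
assert(lookupCte(Seq("db", "T")).isEmpty) // falls through to the catalog
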
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -65,6 +65,17 @@ case class PlanWithUnresolvedIdentifier(
  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER)
}

+/**
+ * A logical plan placeholder that delays CTE resolution until the moment when
+ * `PlanWithUnresolvedIdentifier` is resolved.
+ */
+case class UnresolvedWithCTERelations(
+    unresolvedPlan: LogicalPlan,
+    cteRelations: Seq[(String, CTERelationDef)])
+  extends UnresolvedLeafNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER_WITH_CTE)
+}

/**
* An expression placeholder that holds the identifier clause string expression. It will be
* replaced by the actual expression with the evaluated identifier string.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePattern.scala
@@ -151,6 +151,7 @@ object TreePattern extends Enumeration {
  val UNRESOLVED_FUNCTION: Value = Value
  val UNRESOLVED_HINT: Value = Value
  val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
+  val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value

  // Unresolved Plan patterns (Alphabetically ordered)
  val UNRESOLVED_FUNC: Value = Value
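
For context on why the new enum value is needed: rules such as ResolveIdentifierClause prune their traversal with containsAnyPattern, skipping any subtree that advertises none of the requested patterns, so UnresolvedWithCTERelations must expose its own pattern or a pruned rule would never visit it. A minimal sketch of that pruning mechanism (illustrative types, not the Catalyst implementation):

object Pattern extends Enumeration {
  val UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE = Value
}

trait Tree {
  def children: Seq[Tree]
  def nodePatterns: Set[Pattern.Value]
  // Patterns present in this node or any descendant; Catalyst caches the
  // equivalent information in a per-node bit set.
  def treePatterns: Set[Pattern.Value] =
    children.foldLeft(nodePatterns)(_ ++ _.treePatterns)
  // A pruned traversal descends into a subtree only when this returns true.
  def containsAnyPattern(ps: Pattern.Value*): Boolean = ps.exists(treePatterns)
}

// Without advertising UNRESOLVED_IDENTIFIER_WITH_CTE as its nodePattern, a
// wrapper over a pattern-free child would be skipped by such a traversal.
case class Wrapper(child: Tree) extends Tree {
  val children: Seq[Tree] = Seq(child)
  val nodePatterns: Set[Pattern.Value] = Set(Pattern.UNRESOLVED_IDENTIFIER_WITH_CTE)
}
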
sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -985,6 +985,79 @@ DropTable false, false
+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2


-- !query
DECLARE agg = 'max'
-- !query analysis
CreateVariable defaultvalueexpression(max, 'max'), false
+- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.agg


-- !query
DECLARE col = 'c1'
-- !query analysis
CreateVariable defaultvalueexpression(c1, 'c1'), false
+- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.col


-- !query
DECLARE tab = 'T'
-- !query analysis
CreateVariable defaultvalueexpression(T, 'T'), false
+- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.tab


-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias S
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
:- CTERelationDef xxxx, false
: +- SubqueryAlias T
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Aggregate [max(c1#x) AS max(c1)#x]
+- SubqueryAlias T
+- CTERelationRef xxxx, true, [c1#x, c2#x], false


-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T')
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias S
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
:- CTERelationDef xxxx, false
: +- SubqueryAlias T
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Aggregate [max(c1#x) AS max(c1)#x]
+- SubqueryAlias T
+- CTERelationRef xxxx, true, [c1#x, c2#x], false


-- !query
WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC')
-- !query analysis
WithCTE
:- CTERelationDef xxxx, false
: +- SubqueryAlias ABC
: +- Project [col1#x AS c1#x, col2#x AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Aggregate [max(c1#x) AS max(c1)#x]
+- SubqueryAlias ABC
+- CTERelationRef xxxx, true, [c1#x, c2#x], false


-- !query
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query analysis
…
16 changes: 16 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
@@ -141,6 +141,22 @@ drop view v1;
drop table t1;
drop table t2;

-- SPARK-46625: CTE reference with identifier clause and session variables
DECLARE agg = 'max';
DECLARE col = 'c1';
DECLARE tab = 'T';

WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab);

WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T');

WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC');

-- Not supported
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1);
SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'));

sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1115,6 +1115,59 @@ struct<>



-- !query
DECLARE agg = 'max'
-- !query schema
struct<>
-- !query output



-- !query
DECLARE col = 'c1'
-- !query schema
struct<>
-- !query output



-- !query
DECLARE tab = 'T'
-- !query schema
struct<>
-- !query output



-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)
-- !query schema
struct<max(c1):string>
-- !query output
c


-- !query
WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T')
-- !query schema
struct<max(c1):string>
-- !query output
c


-- !query
WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC')
-- !query schema
struct<max(c1):int>
-- !query output
2


-- !query
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query schema
…