diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index 30619f21bb8f..f24227abbb65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -21,7 +21,7 @@
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_IDENTIFIER_WITH_CTE, UNRESOLVED_WITH}
 import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types.DataType
@@ -189,7 +189,8 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase {
       // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
       // relations are not children of `UnresolvedWith`.
       case NameParameterizedQuery(child, argNames, argValues)
-          if !child.containsPattern(UNRESOLVED_WITH) && argValues.forall(_.resolved) =>
+          if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) &&
+            argValues.forall(_.resolved) =>
         if (argNames.length != argValues.length) {
           throw SparkException.internalError(s"The number of argument names ${argNames.length} " +
             s"must be equal to the number of argument values ${argValues.length}.")
@@ -199,7 +200,8 @@
         bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
 
       case PosParameterizedQuery(child, args)
-          if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_.resolved) =>
+          if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) &&
+            args.forall(_.resolved) =>
         val indexedArgs = args.zipWithIndex
         checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index c90b34d45e78..791bcc91d509 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -741,4 +741,21 @@
       Row("c1"))
     }
   }
+
+  test("SPARK-50322: parameterized identifier in a sub-query") {
+    withTable("tt1") {
+      sql("CREATE TABLE tt1 (c1 INT)")
+      sql("INSERT INTO tt1 VALUES (1)")
+      def query(p: String): String = {
+        s"""
+           |WITH v1 AS (
+           |  SELECT * FROM tt1
+           |  WHERE 1 = (SELECT * FROM IDENTIFIER($p))
+           |) SELECT * FROM v1""".stripMargin
+      }
+
+      checkAnswer(spark.sql(query(":tab"), args = Map("tab" -> "tt1")), Row(1))
+      checkAnswer(spark.sql(query("?"), args = Array("tt1")), Row(1))
+    }
+  }
 }