diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6665d885554fc..3c995573d53d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -509,19 +509,21 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe /** * Split the plan for a scalar subquery into the parts above the innermost query block * (first part of returned value), the HAVING clause of the innermost query block - * (optional second part) and the parts below the HAVING CLAUSE (third part). + * (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part). + * When the third part is empty, it means the subquery is a non-aggregated single-row subquery. */ - private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { + private def splitSubquery( + plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = { val topPart = ArrayBuffer.empty[LogicalPlan] var bottomPart: LogicalPlan = plan while (true) { bottomPart match { case havingPart @ Filter(_, aggPart: Aggregate) => - return (topPart.toSeq, Option(havingPart), aggPart) + return (topPart.toSeq, Option(havingPart), Some(aggPart)) case aggPart: Aggregate => // No HAVING clause - return (topPart.toSeq, None, aggPart) + return (topPart.toSeq, None, Some(aggPart)) case p @ Project(_, child) => topPart += p @@ -531,6 +533,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe topPart += s bottomPart = child + case p: LogicalPlan if p.maxRows.exists(_ <= 1) => + // Non-aggregated one row subquery. 
+ return (topPart.toSeq, None, None) + case Filter(_, op) => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op, " below filter") @@ -561,72 +567,80 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) + lazy val planWithoutCountBug = Project( + currentChild.output :+ origOutput, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) + if (resultWithZeroTups.isEmpty) { // CASE 1: Subquery guaranteed not to have the COUNT bug - Project( - currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) + planWithoutCountBug } else { - // Subquery might have the COUNT bug. Add appropriate corrections. val (topPart, havingNode, aggNode) = splitSubquery(query) - - // The next two cases add a leading column to the outer join input to make it - // possible to distinguish between the case when no tuples join and the case - // when the tuple that joins contains null values. - // The leading column always has the value TRUE. 
- val alwaysTrueExprId = NamedExpression.newExprId - val alwaysTrueExpr = Alias(Literal.TrueLiteral, - ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) - val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, - BooleanType)(exprId = alwaysTrueExprId) - - val aggValRef = query.output.head - - if (havingNode.isEmpty) { - // CASE 2: Subquery with no HAVING clause - val subqueryResultExpr = - Alias(If(IsNull(alwaysTrueRef), - resultWithZeroTups.get, - aggValRef), origOutput.name)() - subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute)) - Project( - currentChild.output :+ subqueryResultExpr, - Join(currentChild, - Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) - + if (aggNode.isEmpty) { + // SPARK-40862: When the aggregate node is empty, it means the subquery produces + // at most one row and it is not subject to the COUNT bug. + planWithoutCountBug } else { - // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. - // Need to modify any operators below the join to pass through all columns - // referenced in the HAVING clause. - var subqueryRoot: UnaryNode = aggNode - val havingInputs: Seq[NamedExpression] = aggNode.output - - topPart.reverse.foreach { - case Project(projList, _) => - subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s @ SubqueryAlias(alias, _) => - subqueryRoot = SubqueryAlias(alias, subqueryRoot) - case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op) + // Subquery might have the COUNT bug. Add appropriate corrections. + val aggregate = aggNode.get + + // The next two cases add a leading column to the outer join input to make it + // possible to distinguish between the case when no tuples join and the case + // when the tuple that joins contains null values. + // The leading column always has the value TRUE. 
+ val alwaysTrueExprId = NamedExpression.newExprId + val alwaysTrueExpr = Alias(Literal.TrueLiteral, + ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) + val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, + BooleanType)(exprId = alwaysTrueExprId) + + val aggValRef = query.output.head + + if (havingNode.isEmpty) { + // CASE 2: Subquery with no HAVING clause + val subqueryResultExpr = + Alias(If(IsNull(alwaysTrueRef), + resultWithZeroTups.get, + aggValRef), origOutput.name)() + subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute)) + Project( + currentChild.output :+ subqueryResultExpr, + Join(currentChild, + Project(query.output :+ alwaysTrueExpr, query), + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) + + } else { + // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. + // Need to modify any operators below the join to pass through all columns + // referenced in the HAVING clause. + var subqueryRoot: UnaryNode = aggregate + val havingInputs: Seq[NamedExpression] = aggregate.output + + topPart.reverse.foreach { + case Project(projList, _) => + subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) + case s@SubqueryAlias(alias, _) => + subqueryRoot = SubqueryAlias(alias, subqueryRoot) + case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op) + } + + // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups + // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggregate val>) + // ELSE (aggregate value) END AS (original column name) + val caseExpr = Alias(CaseWhen(Seq( + (IsNull(alwaysTrueRef), resultWithZeroTups.get), + (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), + aggValRef), + origOutput.name)() + + subqueryAttrMapping += ((origOutput, caseExpr.toAttribute)) + + Project( + currentChild.output :+ caseExpr, + Join(currentChild, + Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } - - 
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups - // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggregate val>) - // ELSE (aggregate value) END AS (original column name) - val caseExpr = Alias(CaseWhen(Seq( - (IsNull(alwaysTrueRef), resultWithZeroTups.get), - (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), - aggValRef), - origOutput.name)() - - subqueryAttrMapping += ((origOutput, caseExpr.toAttribute)) - - Project( - currentChild.output :+ caseExpr, - Join(currentChild, - Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) - } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4b58635636771..7b67648d4752a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2491,4 +2491,21 @@ class SubquerySuite extends QueryTest Row("a")) } } + + test("SPARK-40862: correlated one-row subquery with non-deterministic expressions") { + import org.apache.spark.sql.functions.udf + withTempView("t1") { + sql("CREATE TEMP VIEW t1 AS SELECT ARRAY('a', 'b') a") + val func = udf(() => "a") + spark.udf.register("func", func.asNondeterministic()) + checkAnswer(sql( + """ + |SELECT ( + | SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] || str AS sorted + | FROM (SELECT MAP('a', 1, 'b', 2) rank, func() AS str) + |) FROM t1 + |""".stripMargin), + Row("aa")) + } + } }