diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
index a8832aada083..271e151e709c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{
   Project,
   ReplaceTable,
   Union,
+  UnionLoop,
   Unpivot
 }
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -49,6 +50,7 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
 import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
+import org.apache.spark.sql.errors.DataTypeErrors.cannotMergeIncompatibleDataTypesError
 import org.apache.spark.sql.types.DataType
 
 abstract class TypeCoercionBase extends TypeCoercionHelper {
@@ -247,6 +249,25 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
         val attrMapping = s.children.head.output.zip(newChildren.head.output)
         s.copy(children = newChildren) -> attrMapping
       }
+
+      case s: UnionLoop
+          if s.childrenResolved && s.anchor.output.length == s.recursion.output.length
+            && !s.resolved =>
+        // If the anchor data type is wider than the recursion data type, we cast the recursion
+        // type to match the anchor type.
+        // On the other hand, we cannot cast the anchor type into a wider recursion type, as at
+        // this point the UnionLoopRefs inside the recursion are already resolved with the
+        // narrower anchor type.
+        val projectList = s.recursion.output.zip(s.anchor.output.map(_.dataType)).map {
+          case (attr, dt) =>
+            val widerType = findWiderTypeForTwo(attr.dataType, dt)
+            if (widerType.isDefined && widerType.get == dt) {
+              Alias(Cast(attr, dt), attr.name)()
+            } else {
+              throw cannotMergeIncompatibleDataTypesError(dt, attr.dataType)
+            }
+        }
+        s.copy(recursion = Project(projectList, s.recursion)) -> Nil
     }
   }
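Reviewer note: the asymmetry of this rule can be sanity-checked in isolation. Below is a hypothetical, self-contained sketch (not Spark's actual helper); `widerTypeForTwo` and `canCoerceRecursion` are made-up names, and the stand-in only covers the INT/BIGINT promotion exercised by the new tests, while Spark's real `findWiderTypeForTwo` handles far more cases. It only assumes `spark-catalyst` on the classpath for the type imports.

```scala
import org.apache.spark.sql.types._

// Hypothetical stand-in for findWiderTypeForTwo, limited to INT/BIGINT.
def widerTypeForTwo(a: DataType, b: DataType): Option[DataType] = (a, b) match {
  case (x, y) if x == y => Some(x)
  case (IntegerType, LongType) | (LongType, IntegerType) => Some(LongType)
  case _ => None
}

// The rule is deliberately asymmetric: a cast is only inserted when the anchor
// type is the wider one, because the UnionLoopRefs inside the recursion were
// already resolved with the anchor's schema.
def canCoerceRecursion(anchor: DataType, recursion: DataType): Boolean =
  widerTypeForTwo(recursion, anchor).contains(anchor)

assert(canCoerceRecursion(LongType, IntegerType))  // anchor BIGINT, recursion INT: cast up
assert(!canCoerceRecursion(IntegerType, LongType)) // recursion wider: CANNOT_MERGE error
```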
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 8a87cdcbb0fb..3f3a60e19641 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -405,8 +405,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
       newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) }
     }
     val resultAttrMapping = if (canGetOutput(plan)) {
-      // We propagate the attributes mapping to the parent plan node to update attributes, so
-      // the `newAttr` must be part of this plan's output.
       (transferAttrMapping ++ newOtherAttrMapping).filter {
         case (_, newAttr) => planAfterRule.outputSet.contains(newAttr)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ba56e8599e56..5215e8a9568b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -546,6 +546,23 @@ abstract class UnionBase extends LogicalPlan {
       .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
       .reduce(merge(_, _))
   }
+
+  /**
+   * Checks whether the child outputs are compatible by using `DataType.equalsStructurally`. Do
+   * that by comparing the size of the output with the size of the first child's output and by
+   * comparing output data types with the data types of the first child's output.
+   *
+   * This method needs to be evaluated after `childrenResolved`.
+   */
+  def allChildrenCompatible: Boolean = childrenResolved && children.tail.forall { child =>
+    child.output.length == children.head.output.length &&
+      child.output.zip(children.head.output).forall {
+        case (l, r) => DataType.equalsStructurally(l.dataType, r.dataType, true)
+      }
+  }
 }
 
 /**
@@ -606,20 +623,6 @@ case class Union(
     children.length > 1 && !(byName || allowMissingCol) && childrenResolved &&
       allChildrenCompatible
   }
 
-  /**
-   * Checks whether the child outputs are compatible by using `DataType.equalsStructurally`. Do
-   * that by comparing the size of the output with the size of the first child's output and by
-   * comparing output data types with the data types of the first child's output.
-   *
-   * This method needs to be evaluated after `childrenResolved`.
-   */
-  def allChildrenCompatible: Boolean = childrenResolved && children.tail.forall { child =>
-    child.output.length == children.head.output.length &&
-      child.output.zip(children.head.output).forall {
-        case (l, r) => DataType.equalsStructurally(l.dataType, r.dataType, true)
-      }
-  }
-
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
     copy(children = newChildren)
 }
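Reviewer note: `allChildrenCompatible` moves unchanged from `Union` up to `UnionBase` so that `UnionLoop` can reuse it. A quick illustration of what `DataType.equalsStructurally` accepts with its third argument set to `true` (ignore nullability), as in the method above. The struct values here are made up for the example, and since `equalsStructurally` is catalyst-internal, the sketch assumes it is run from inside Spark's own `org.apache.spark.sql` package (e.g. a catalyst test):

```scala
package org.apache.spark.sql // assumed: equalsStructurally is private[sql]

import org.apache.spark.sql.types._

object EqualsStructurallyDemo {
  // equalsStructurally compares type "shape": struct field names are ignored,
  // and with ignoreNullability = true nullability differences are ignored too.
  val a = StructType(Seq(StructField("x", IntegerType, nullable = false)))
  val b = StructType(Seq(StructField("y", IntegerType, nullable = true)))

  assert(DataType.equalsStructurally(a, b, ignoreNullability = true))

  // Different leaf types are still incompatible; that is exactly the case the
  // new UnionLoop coercion rule has to cast away (or reject) first.
  assert(!DataType.equalsStructurally(a, StructType(Seq(StructField("x", LongType)))))
}
```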
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
index 072aa4540775..cea342d37c06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
@@ -59,6 +59,11 @@ case class UnionLoop(
     id.toString + limit.map(", " + _.toString).getOrElse("") +
       maxDepth.map(", " + _.toString).getOrElse("")
   }
+
+  override lazy val resolved: Boolean = {
+    // allChildrenCompatible needs to be evaluated after childrenResolved
+    childrenResolved && allChildrenCompatible
+  }
 }
 
 /**
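Reviewer note: this `resolved` override is what keeps the analyzer iterating. `UnionLoop` refuses to report itself resolved until the coercion rule above has reconciled the anchor and recursion schemas. A toy model of that gating contract, with purely illustrative names (this is not Spark's API):

```scala
// Toy model of resolution gating; ToyNode and ToyUnionLoop are made-up names.
trait ToyNode {
  def children: Seq[ToyNode]
  def childrenResolved: Boolean = children.forall(_.resolved)
  // Default: a node is resolved as soon as its children are.
  def resolved: Boolean = childrenResolved
}

// Like UnionLoop, this node additionally demands compatible child schemas, so
// a fixed-point analyzer keeps firing coercion rules until a cast is in place
// (or an incompatibility error is raised).
final case class ToyUnionLoop(children: Seq[ToyNode], schemasCompatible: () => Boolean)
    extends ToyNode {
  override def resolved: Boolean = childrenResolved && schemasCompatible()
}
```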
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
index cc04d51193d3..a252fb223788 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
@@ -1830,6 +1830,49 @@ SELECT val FROM randoms LIMIT 5
 [Analyzer test output redacted due to nondeterminism]
 
 
+-- !query
+WITH RECURSIVE t1(n, m) AS (
+  SELECT 1, CAST(1 AS BIGINT)
+  UNION ALL
+  SELECT n+1, n+1 FROM t1 WHERE n < 5)
+SELECT * FROM t1
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias t1
+:     +- Project [1#x AS n#x, CAST(1 AS BIGINT)#xL AS m#xL]
+:        +- UnionLoop xxxx
+:           :- Project [1 AS 1#x, cast(1 as bigint) AS CAST(1 AS BIGINT)#xL]
+:           :  +- OneRowRelation
+:           +- Project [cast((n + 1)#x as int) AS (n + 1)#x, cast((n + 1)#x as bigint) AS (n + 1)#xL]
+:              +- Project [(n#x + 1) AS (n + 1)#x, (n#x + 1) AS (n + 1)#x]
+:                 +- Filter (n#x < 5)
+:                    +- SubqueryAlias t1
+:                       +- Project [1#x AS n#x, CAST(1 AS BIGINT)#xL AS m#xL]
+:                          +- UnionLoopRef xxxx, [1#x, CAST(1 AS BIGINT)#xL], false
++- Project [n#x, m#xL]
+   +- SubqueryAlias t1
+      +- CTERelationRef xxxx, true, [n#x, m#xL], false, false
+
+
+-- !query
+WITH RECURSIVE t1(n, m) AS (
+  SELECT 1, 1
+  UNION ALL
+  SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
+SELECT * FROM t1
+-- !query analysis
+org.apache.spark.SparkException
+{
+  "errorClass" : "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE",
+  "sqlState" : "42825",
+  "messageParameters" : {
+    "left" : "\"INT\"",
+    "right" : "\"BIGINT\""
+  }
+}
+
+
 -- !query
 WITH RECURSIVE t1(n) AS (
   SELECT 1
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
index fc7306dc8b68..b03d853f2f09 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -683,6 +683,20 @@ WITH RECURSIVE randoms(val) AS (
 )
 SELECT val FROM randoms LIMIT 5;
 
+-- Type coercion where the anchor is wider
+WITH RECURSIVE t1(n, m) AS (
+  SELECT 1, CAST(1 AS BIGINT)
+  UNION ALL
+  SELECT n+1, n+1 FROM t1 WHERE n < 5)
+SELECT * FROM t1;
+
+-- Type coercion where the recursion is wider
+WITH RECURSIVE t1(n, m) AS (
+  SELECT 1, 1
+  UNION ALL
+  SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
+SELECT * FROM t1;
+
 -- Recursive CTE with nullable recursion and non-recursive anchor
 WITH RECURSIVE t1(n) AS (
   SELECT 1
diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
index f6dcadb326c4..e41efc7b78e5 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -1675,6 +1675,42 @@ struct>
 [4,5,1,2,3]
 
 
+-- !query
+WITH RECURSIVE t1(n, m) AS (
+  SELECT 1, CAST(1 AS BIGINT)
+  UNION ALL
+  SELECT n+1, n+1 FROM t1 WHERE n < 5)
+SELECT * FROM t1
+-- !query schema
+struct<n:int,m:bigint>
+-- !query output
+1	1
+2	2
+3	3
+4	4
+5	5
+
+
+-- !query
+WITH RECURSIVE t1(n, m) AS (
+  SELECT 1, 1
+  UNION ALL
+  SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
+SELECT * FROM t1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkException
+{
+  "errorClass" : "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE",
+  "sqlState" : "42825",
+  "messageParameters" : {
+    "left" : "\"INT\"",
+    "right" : "\"BIGINT\""
+  }
+}
+
+
 -- !query
 WITH RECURSIVE t1(n) AS (
   SELECT 1
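Reviewer note: for trying the two new cases outside the golden-file harness, a hypothetical driver sketch is below. It assumes a Spark build in which recursive CTEs are available (a feature flag may also be required depending on the build); everything else is public API.

```scala
import org.apache.spark.sql.SparkSession

object RecursiveCteCoercionDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()

  // Anchor is wider (m is BIGINT): the recursion side is cast up, so this runs.
  spark.sql(
    """WITH RECURSIVE t1(n, m) AS (
      |  SELECT 1, CAST(1 AS BIGINT)
      |  UNION ALL
      |  SELECT n+1, n+1 FROM t1 WHERE n < 5)
      |SELECT * FROM t1""".stripMargin).show()

  // Recursion is wider: the UnionLoopRef is already bound to the INT anchor
  // schema, so analysis fails with CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE.
  try {
    spark.sql(
      """WITH RECURSIVE t1(n, m) AS (
        |  SELECT 1, 1
        |  UNION ALL
        |  SELECT n+1, CAST(n+1 AS BIGINT) FROM t1 WHERE n < 5)
        |SELECT * FROM t1""".stripMargin).show()
  } catch {
    case e: org.apache.spark.SparkException => println(e.getMessage)
  }
}
```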