diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9c97e1e9b441..dfd5510b52a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1325,25 +1325,50 @@ class Analyzer( * * Note : In this routine, the unresolved attributes are resolved from the input plan's * children attributes. + * + * @param e The expression that needs to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @param trimAlias When true, trim unnecessary alias of `GetStructField`. Note that, + * we cannot trim the alias of top-level `GetStructField`, as we should + * resolve `UnresolvedAttribute` to a named expression. The caller side + * can trim the alias of top-level `GetStructField` if it's safe to do so. + * @return resolved Expression. */ - private def resolveExpressionTopDown(e: Expression, q: LogicalPlan): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver) - .orElse(resolveLiteralFunction(nameParts, u, q)) - .getOrElse(u) + private def resolveExpressionTopDown( + e: Expression, + q: LogicalPlan, + trimAlias: Boolean = false): Expression = { + + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + case u @ UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. 
+ val resolved = + withPosition(u) { + q.resolveChildren(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, q)) + .getOrElse(u) + } + val result = resolved match { + // As the comment of method `resolveExpressionTopDown`'s param `trimAlias` said, + // when trimAlias = true, we will trim unnecessary alias of `GetStructField` and + // we won't trim the alias of top-level `GetStructField`. Since we will call + // CleanupAliases later in Analyzer, trimming non top-level unnecessary alias of + // `GetStructField` here is safe. + case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s + case others => others } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - case _ => e.mapChildren(resolveExpressionTopDown(_, q)) + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } } + + innerResolve(e, isTopLevel = true) } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { @@ -1425,11 +1450,49 @@ class Analyzer( // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan - // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of - // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute - // names leading to ambiguous references exception. - case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => - a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) + // SPARK-31670: Resolve Struct field in groupByExpressions and aggregateExpressions + // with CUBE/ROLLUP will be wrapped with alias like Alias(GetStructField, name) with + // different ExprId. 
This prevents aggregateExpressions from being replaced by expanded + // groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`, so we trim + // unnecessary alias of GetStructField here. + case a: Aggregate => + val planForResolve = a.child match { + // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of + // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute + // names leading to ambiguous references exception. + case appendColumns: AppendColumns => appendColumns + case _ => a + } + + val resolvedGroupingExprs = a.groupingExpressions + .map(resolveExpressionTopDown(_, planForResolve, trimAlias = true)) + .map(trimTopLevelGetStructFieldAlias) + + val resolvedAggExprs = a.aggregateExpressions + .map(resolveExpressionTopDown(_, planForResolve, trimAlias = true)) + .map(_.asInstanceOf[NamedExpression]) + + a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child) + + // SPARK-31670: Resolve Struct field in selectedGroupByExprs/groupByExprs and aggregations + // will be wrapped with alias like Alias(GetStructField, name) with different ExprId. + // This prevents aggregateExpressions from being replaced by expanded groupByExpressions in + // `ResolveGroupingAnalytics.constructAggregateExprs()`, so we trim unnecessary alias + // of GetStructField here. 
+ case g: GroupingSets => + val resolvedSelectedExprs = g.selectedGroupByExprs + .map(_.map(resolveExpressionTopDown(_, g, trimAlias = true)) + .map(trimTopLevelGetStructFieldAlias)) + + val resolvedGroupingExprs = g.groupByExprs + .map(resolveExpressionTopDown(_, g, trimAlias = true)) + .map(trimTopLevelGetStructFieldAlias) + + val resolvedAggExprs = g.aggregations + .map(resolveExpressionTopDown(_, g, trimAlias = true)) + .map(_.asInstanceOf[NamedExpression]) + + g.copy(resolvedSelectedExprs, resolvedGroupingExprs, g.child, resolvedAggExprs) case o: OverwriteByExpression if !o.outputResolved => // do not resolve expression attributes until the query attributes are resolved against the @@ -1525,6 +1588,16 @@ class Analyzer( AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) } + // This method is used to trim groupByExpressions/selectedGroupByExpressions's top-level + // GetStructField Alias. Since these expressions are not NamedExpression originally, + // it is safe to trim the top-level GetStructField Alias. 
+ def trimTopLevelGetStructFieldAlias(e: Expression): Expression = { + e match { + case Alias(s: GetStructField, _) => s + case other => other + } + } + /** * Build a project list for Project/Aggregate and expand the star if possible */ diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 26a44a85841e..c3dd8c0f9f71 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -110,6 +110,6 @@ struct<> -- !query SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") -- !query schema -struct +struct -- !query output gamma 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f7a904169d6c..1b4a0df1913b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3496,6 +3496,108 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkIfSeedExistsInExplain(df2) } + test("SPARK-31670: Trim unnecessary Struct field alias in Aggregate/GroupingSets") { + withTempView("t") { + sql( + """ + |CREATE TEMPORARY VIEW t(a, b, c) AS + |SELECT * FROM VALUES + |('A', 1, NAMED_STRUCT('row_id', 1, 'json_string', '{"i": 1}')), + |('A', 2, NAMED_STRUCT('row_id', 2, 'json_string', '{"i": 1}')), + |('A', 2, NAMED_STRUCT('row_id', 2, 'json_string', '{"i": 2}')), + |('B', 1, NAMED_STRUCT('row_id', 3, 'json_string', '{"i": 1}')), + |('C', 3, NAMED_STRUCT('row_id', 4, 'json_string', '{"i": 1}')) + """.stripMargin) + + checkAnswer( + sql( + """ + |SELECT a, c.json_string, SUM(b) + |FROM t + |GROUP BY a, c.json_string + |""".stripMargin), + Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 
2}", 2) :: + Row("B", "{\"i\": 1}", 1) :: Row("C", "{\"i\": 1}", 3) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, c.json_string, SUM(b) + |FROM t + |GROUP BY a, c.json_string + |WITH CUBE + |""".stripMargin), + Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) :: Row("A", null, 5) :: + Row("B", "{\"i\": 1}", 1) :: Row("B", null, 1) :: + Row("C", "{\"i\": 1}", 3) :: Row("C", null, 3) :: + Row(null, "{\"i\": 1}", 7) :: Row(null, "{\"i\": 2}", 2) :: Row(null, null, 9) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, get_json_object(c.json_string, '$.i'), SUM(b) + |FROM t + |GROUP BY a, get_json_object(c.json_string, '$.i') + |WITH CUBE + |""".stripMargin), + Row("A", "1", 3) :: Row("A", "2", 2) :: Row("A", null, 5) :: + Row("B", "1", 1) :: Row("B", null, 1) :: + Row("C", "1", 3) :: Row("C", null, 3) :: + Row(null, "1", 7) :: Row(null, "2", 2) :: Row(null, null, 9) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, c.json_string AS json_string, SUM(b) + |FROM t + |GROUP BY a, c.json_string + |WITH CUBE + |""".stripMargin), + Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) :: + Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) :: + Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) :: + Row(null, null, 9) :: Row(null, "{\"i\": 1}", 7) :: Row(null, "{\"i\": 2}", 2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, c.json_string as js, SUM(b) + |FROM t + |GROUP BY a, c.json_string + |WITH CUBE + |""".stripMargin), + Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) :: + Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) :: + Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) :: + Row(null, null, 9) :: Row(null, "{\"i\": 1}", 7) :: Row(null, "{\"i\": 2}", 2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, c.json_string as js, SUM(b) + |FROM t + |GROUP BY a, c.json_string + |WITH ROLLUP + |""".stripMargin), + Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) :: + Row("B", null, 1) :: Row("B", 
"{\"i\": 1}", 1) :: + Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) :: + Row(null, null, 9) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, c.json_string, SUM(b) + |FROM t + |GROUP BY a, c.json_string + |GROUPING sets((a),(a, c.json_string)) + |""".stripMargin), + Row("A", null, 5) :: Row("A", "{\"i\": 1}", 3) :: Row("A", "{\"i\": 2}", 2) :: + Row("B", null, 1) :: Row("B", "{\"i\": 1}", 1) :: + Row("C", null, 3) :: Row("C", "{\"i\": 1}", 3) :: Nil) + } + } + test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") { checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1))) checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"),