From c77fc72774d6fd4ee30b497b070636f061502502 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 7 Mar 2017 21:28:47 +0900 Subject: [PATCH 01/23] Resolve aliases in GROUP-BY --- .../sql/catalyst/analysis/Analyzer.scala | 19 +++++++++++++++++++ .../resources/sql-tests/inputs/group-by.sql | 3 +++ .../sql-tests/results/group-by.sql.out | 13 ++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ 4 files changed, 42 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcadbbc90f43..fc286b562a4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -148,6 +148,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + ResolveGroupByAlias :: TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: @@ -1636,6 +1637,24 @@ class Analyzer( } } + /** + * Resolve aliases in a GROUP BY clause. + */ + object ResolveGroupByAlias extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute => + aggs.find(ne => resolver(ne.name, u.name)).map { + case alias @ Alias(e, _) => e + case e => e + }.getOrElse(u) + case e => e + }) + } + } + /** * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] * operator under [[Project]]. diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed4315300..25a9425a8d2c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,6 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. 
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aliases in SELECT could be used in GROUP BY +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0..07abd7cb810b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 16 -- !query 0 @@ -139,3 +139,14 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct -- !query 14 output 1 1 + + +-- !query 15 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +1 2 +2 2 +3 2 +NULL 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd9296a3f0f..2ce4b0485aeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2606,4 +2606,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } + + test("SPARK-14471 Aliases in SELECT could be used in GROUP BY") { + Seq(("a", "a", 0), ("b", "a", 1), ("a", "a", 2)).toDF("k1", "k2", "v") + .createOrReplaceTempView("t") + checkAnswer( + sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY k2, k1"), + Row("a", "a", 2) :: Row("b", "a", 1) :: Nil) + } } From e2c3e5fbb76021c05c8bc9005bc0322f4136e7e3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 10 Mar 2017 20:17:20 +0900 Subject: [PATCH 02/23] Apply comments --- .../sql/catalyst/analysis/Analyzer.scala | 37 +++++++------------ .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +++-- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fc286b562a4c..a0ff93b5c424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -148,7 +148,6 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: - ResolveGroupByAlias :: TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: @@ -845,11 +844,21 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + // We need to replace unresolved attributes with the resolved ones that this plan `q` holds + // as expressions. For example, in `SELECT a AS a1, a1 + 1 AS b`, it replaces the unresolved + // `a1` of `a1 + 1 AS b` with the resolved `a1` of `a AS a1`. + val resolvedExprs = q.expressions.filter { + case ne: NamedExpression if ne.resolved => true + case _ => false + }.map(_.asInstanceOf[NamedExpression]) + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. 
Hopefully will be resolved next round. + val result = withPosition(u) { + q.resolveChildren(nameParts, resolver).getOrElse { + resolvedExprs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + } + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -1637,24 +1646,6 @@ class Analyzer( } } - /** - * Resolve aliases in a GROUP BY clause. - */ - object ResolveGroupByAlias extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case agg @ Aggregate(groups, aggs, child) - if child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => - agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute => - aggs.find(ne => resolver(ne.name, u.name)).map { - case alias @ Alias(e, _) => e - case e => e - }.getOrElse(u) - case e => e - }) - } - } - /** * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] * operator under [[Project]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2ce4b0485aeb..8ba8f1d33223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2608,10 +2608,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-14471 Aliases in SELECT could be used in GROUP BY") { - Seq(("a", "a", 0), ("b", "a", 1), ("a", "a", 2)).toDF("k1", "k2", "v") + Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") .createOrReplaceTempView("t") checkAnswer( - sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY k2, k1"), - Row("a", "a", 2) :: Row("b", "a", 1) :: Nil) + sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2"), + Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) + checkAnswer( + sql("SELECT k1 AS key1, key1 + 1 AS key2, COUNT(1) FROM t GROUP BY key1, key2"), + Row(1, 2, 2) :: Row(2, 3, 1) :: Nil) } } From 9332c0777cf4ef1461ffb8a2da83739f70be2908 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 15 Mar 2017 15:56:32 +0900 Subject: [PATCH 03/23] Brush up code --- .../sql/catalyst/analysis/Analyzer.scala | 26 ++++++++++--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 3 --- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a0ff93b5c424..49570a0cf5fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -838,27 +838,29 @@ class Analyzer( Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } + // If grouping keys have unresolved expressions, we need to replace them with resolved one + // in SELECT clauses. + case agg @ Aggregate(groups, aggs, child) + if child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute => + aggs.find(ne => resolver(ne.name, u.name)).map { + case alias @ Alias(e, _) => e + case e => e + }.getOrElse(u) + case e => e + }) + // Skips plan which contains deserializer expressions, as they should be resolved by another // rule: ResolveDeserializer. 
case plan if containsDeserializer(plan.expressions) => plan case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - // We need to replace unresolved attributes with the resolved ones that this plan `q` holds - // as expressions. For example, in `SELECT a AS a1, a1 + 1 AS b`, it replaces the unresolved - // `a1` of `a1 + 1 AS b` with the resolved `a1` of `a AS a1`. - val resolvedExprs = q.expressions.filter { - case ne: NamedExpression if ne.resolved => true - case _ => false - }.map(_.asInstanceOf[NamedExpression]) q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = withPosition(u) { - q.resolveChildren(nameParts, resolver).getOrElse { - resolvedExprs.find(ne => resolver(ne.name, u.name)).getOrElse(u) - } - } + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8ba8f1d33223..781ed907c9f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2613,8 +2613,5 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2"), Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) - checkAnswer( - sql("SELECT k1 AS key1, key1 + 1 AS key2, COUNT(1) FROM t GROUP BY key1, key2"), - Row(1, 2, 2) :: Row(2, 3, 1) :: Nil) } } From 55ae8aa2a188b7c32dff5e359fc1103b49b9c75e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Mar 2017 02:33:22 +0900 Subject: [PATCH 04/23] Fix test errors --- .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d2ebca5a83dd..3e88be321c2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -164,7 +164,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "nested aggregate functions", - testRelation.groupBy('a)( + testRelation.groupBy()( AggregateExpression( Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)), Complete, From 172482672af7c5f2890ad4d3f7d5a9f4aed02292 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Mar 2017 11:54:50 +0900 Subject: [PATCH 05/23] Add an option --- .../sql/catalyst/analysis/Analyzer.scala | 3 ++- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++++++++++----- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 49570a0cf5fb..fdd6c3368c7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -841,7 +841,8 @@ class Analyzer( 
// If grouping keys have unresolved expressions, we need to replace them with resolved one // in SELECT clauses. case agg @ Aggregate(groups, aggs, child) - if child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => + if conf.groupByAliasesEnabled && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => agg.copy(groupingExpressions = groups.map { case u: UnresolvedAttribute => aggs.find(ne => resolver(ne.name, u.name)).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1798e22b9f..0486cee35093 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -421,6 +421,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GROUP_BY_ALIASES_ENABLED = buildConf("spark.sql.groupByAliasesEnabled") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .booleanConf + .createWithDefault(true) + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = @@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging { def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByAliasesEnabled: Boolean = getConf(GROUP_BY_ALIASES_ENABLED) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 781ed907c9f6..cd804b361c64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2607,11 +2607,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-14471 Aliases in SELECT could be used in GROUP BY") { - Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") - .createOrReplaceTempView("t") - checkAnswer( - sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2"), - Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) + test("SPARK-14471 When groupByAliasesEnabled=true, aliases in SELECT could exist in GROUP BY") { + withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "true") { + Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") + .createOrReplaceTempView("t") + checkAnswer( + sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2"), + Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) + } + withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "false") { + Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") + .createOrReplaceTempView("t") + val errMsg = intercept[AnalysisException] { + sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2") + } + assert(errMsg.getMessage.startsWith("cannot resolve")) + } } } From cd28da01737cebbbfd59e8454a45c13af2b8aec7 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Mar 2017 17:58:30 +0900 Subject: [PATCH 06/23] Tests mixed cases: group-by ordinals and aliases --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cd804b361c64..2321cca8ff36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2614,6 +2614,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2"), Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) + // Check mixed cases: group-by ordinals and aliases + checkAnswer( + sql("SELECT k1, k2 AS key2, SUM(v) FROM t GROUP BY key2, 1"), + Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) } withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "false") { Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") From 90df0b4d90430b440cf1e6b8f50c869458e0f02f Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 18 Mar 2017 11:04:09 +0900 Subject: [PATCH 07/23] Apply comments --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2321cca8ff36..b78007dba295 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2608,7 +2608,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-14471 When groupByAliasesEnabled=true, aliases in SELECT could exist in GROUP BY") { - withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "true") { + withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "true", + SQLConf.GROUP_BY_ORDINAL.key -> "true") { Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") .createOrReplaceTempView("t") checkAnswer( From 95e43614835c96da53e480ffa63147cf8c3373de Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Mar 2017 11:02:26 +0900 Subject: [PATCH 08/23] Apply reviews --- .../sql-tests/inputs/group-by-ordinal.sql | 4 +++ .../resources/sql-tests/inputs/group-by.sql | 7 ++++ .../results/group-by-ordinal.sql.out | 34 ++++++++++++++---- .../sql-tests/results/group-by.sql.out | 35 +++++++++++++++++-- 4 files changed, 71 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 9c8d851e36e9..3ee351f7f255 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -49,6 +49,10 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; -- group by ordinal followed by having select count(a), a from (select 1 as a) tmp group by 2 having a > 0; +-- mixed cases: group-by ordinals and aliases +explain select a, a AS k, count(b) from data group by k, 1; +select a, a AS k, count(b) from data group by k, 1; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 25a9425a8d2c..355af8b6facf 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -37,4 +37,11 @@ FROM testData; SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Aliases in SELECT could be used in GROUP BY +EXPLAIN SELECT a AS k, 
COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; + +-- turn off group by aliases +set spark.sql.groupByAliasesEnabled=false; + +-- Check analysis exceptions SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index d03681d0ea59..3fab1a261a33 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 21 -- !query 0 @@ -173,16 +173,38 @@ struct -- !query 17 -set spark.sql.groupByOrdinal=false +explain select a, a AS k, count(b) from data group by k, 1 -- !query 17 schema -struct +struct -- !query 17 output -spark.sql.groupByOrdinal false +== Physical Plan == +*HashAggregate(keys=[a#x], functions=[count(1)]) ++- Exchange hashpartitioning(a#x, 200) + +- *HashAggregate(keys=[a#x], functions=[partial_count(1)]) + +- LocalTableScan [a#x] -- !query 18 -select sum(b) from data group by -1 +select a, a AS k, count(b) from data group by k, 1 -- !query 18 schema -struct +struct -- !query 18 output +1 1 2 +2 2 2 +3 3 2 + + +-- !query 19 +set spark.sql.groupByOrdinal=false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.groupByOrdinal false + + +-- !query 20 +select sum(b) from data group by -1 +-- !query 20 schema +struct +-- !query 20 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 07abd7cb810b..5c7881956e46 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 19 -- !query 0 @@ -142,11 +142,40 @@ struct -- !query 15 -SELECT a AS k, COUNT(b) FROM testData GROUP BY k +EXPLAIN SELECT a AS k, COUNT(b) FROM testData GROUP BY k -- !query 15 schema -struct +struct -- !query 15 output +== Physical Plan == +*HashAggregate(keys=[a#x], functions=[count(b#x)]) ++- Exchange hashpartitioning(a#x, 200) + +- *HashAggregate(keys=[a#x], functions=[partial_count(b#x)]) + +- LocalTableScan [a#x, b#x] + + +-- !query 16 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 16 schema +struct +-- !query 16 output 1 2 2 2 3 2 NULL 1 + + +-- !query 17 +set spark.sql.groupByAliasesEnabled=false +-- !query 17 schema +struct +-- !query 17 output +spark.sql.groupByAliasesEnabled false + + +-- !query 18 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 From db5979f0b002e23f3ba727c4d8e4377276a3f25d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Mar 2017 12:51:14 +0900 Subject: [PATCH 09/23] Apply more comments --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fdd6c3368c7b..d2e72f930f34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -846,7 +846,7 @@ class Analyzer( agg.copy(groupingExpressions = groups.map { case u: UnresolvedAttribute => aggs.find(ne => resolver(ne.name, u.name)).map { - case alias @ Alias(e, _) => e + case Alias(e, _) => e case e => e }.getOrElse(u) case e => e From a594d2b90b6a81376667b79e32072e147d1863eb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Mar 2017 15:47:27 +0900 Subject: [PATCH 10/23] Add more tests --- .../resources/sql-tests/inputs/group-by.sql | 5 ++++ .../sql-tests/results/group-by.sql.out | 28 +++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 355af8b6facf..43e50e1b3e00 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -40,6 +40,11 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS EXPLAIN SELECT a AS k, COUNT(b) FROM testData GROUP BY k; SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; + -- turn off group by aliases set spark.sql.groupByAliasesEnabled=false; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 5c7881956e46..a1fa2d072504 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 21 -- !query 0 @@ -165,17 +165,35 @@ NULL 1 -- !query 17 -set spark.sql.groupByAliasesEnabled=false +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) -- !query 17 schema -struct +struct<> -- !query 17 output -spark.sql.groupByAliasesEnabled false + -- !query 18 -SELECT a AS k, COUNT(b) FROM testData GROUP BY k +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a -- !query 18 schema struct<> -- !query 18 output org.apache.spark.sql.AnalysisException +expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 19 +set spark.sql.groupByAliasesEnabled=false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.groupByAliasesEnabled false + + +-- !query 20 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 From bba85214f3f232efeadd298375fdc97995fbca84 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Apr 2017 01:33:04 +0900 Subject: [PATCH 11/23] Add a new rule for group-by aliases --- .../sql/catalyst/analysis/Analyzer.scala | 54 ++++++++++++++----- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d2e72f930f34..7f891404709e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -162,6 +162,10 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), + Batch("ResolveAggAliasInGroupBy", Once, + ResolveAggAliasInGroupBy), + Batch("ResolveTimeZone", Once, + ResolveTimeZone), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, @@ -838,20 +842,6 @@ class Analyzer( Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } - // If grouping keys have unresolved expressions, we need to replace them with resolved one - // in SELECT clauses. - case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliasesEnabled && child.resolved && aggs.forall(_.resolved) && - groups.exists(!_.resolved) => - agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute => - aggs.find(ne => resolver(ne.name, u.name)).map { - case Alias(e, _) => e - case e => e - }.getOrElse(u) - case e => e - }) - // Skips plan which contains deserializer expressions, as they should be resolved by another // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan @@ -2380,6 +2370,42 @@ class Analyzer( } } } + + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliasesEnabled && child.resolved && groups.exists(!_.resolved) => + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute => + aggs.find(ne => resolver(ne.name, u.name)).map { + case Alias(e, _) => e + case e => e + }.getOrElse(u) + case e => e + }) + } + } + + /** + * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local + * time zone. + */ + object ResolveTimeZone extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. 
+      // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
+      // information for time zone aware expressions.
+      case e: ListQuery => e.withNewPlan(apply(e.plan))
+    }
+  }
 }

From 6c3c5fa7c49950cea0ca975f81815adfc45b574e Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 19 Apr 2017 01:35:41 +0900
Subject: [PATCH 12/23] Remove infix notations

---
 .../sql/catalyst/analysis/Analyzer.scala      | 42 +++++++++----------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7f891404709e..4831aeffb352 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -176,7 +176,7 @@ class Analyzer(
    * Analyze cte definitions and substitute child plan with analyzed cte definitions.
    */
   object CTESubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case With(child, relations) =>
         substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
           case (resolved, (name, relation)) =>
@@ -204,7 +204,7 @@ class Analyzer(
    * Substitute child plan with WindowSpecDefinitions.
    */
   object WindowsSubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       // Lookup WindowSpecDefinitions. This rule works with unresolved children.
       case WithWindowDefinition(windowDefinitions, child) =>
         child.transform {
@@ -246,7 +246,7 @@ class Analyzer(
     private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
       exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
 
@@ -618,7 +618,7 @@ class Analyzer(
       case _ => plan
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
         EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
           case v: View =>
@@ -790,7 +790,7 @@ class Analyzer(
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
       // If the projection list contains Stars, expand it.
@@ -964,7 +964,7 @@ class Analyzer(
    * have no effect on the results.
    */
   object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.childrenResolved => p
       // Replace the index with the related attribute for ORDER BY,
       // which is a 1-base position of the projection list.
@@ -1009,7 +1009,7 @@ class Analyzer(
    * The HAVING clause could also use grouping columns that are not present in the SELECT.
*/ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1133,7 +1133,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1472,7 +1472,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1487,7 +1487,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1513,7 +1513,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1685,7 +1685,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1743,7 +1743,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2060,7 +2060,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2105,7 +2105,7 @@ class Analyzer( * and we should return null if the input is null. 
*/ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2170,7 +2170,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2235,7 +2235,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2321,7 +2321,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2355,7 +2355,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2445,7 +2445,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2513,7 +2513,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. 
*/ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = From 110ec5ee4987d29dc9de523a0ecdb74832883c66 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Apr 2017 08:34:00 +0900 Subject: [PATCH 13/23] Remove unnecessary code --- .../org/apache/spark/sql/SQLQuerySuite.scala | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b78007dba295..0dd9296a3f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2606,27 +2606,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } - - test("SPARK-14471 When groupByAliasesEnabled=true, aliases in SELECT could exist in GROUP BY") { - withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "true", - SQLConf.GROUP_BY_ORDINAL.key -> "true") { - Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") - .createOrReplaceTempView("t") - checkAnswer( - sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2"), - Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) - // Check mixed cases: group-by ordinals and aliases - checkAnswer( - sql("SELECT k1, k2 AS key2, SUM(v) FROM t GROUP BY key2, 1"), - Row(1, "a", 2) :: Row(2, "a", 1) :: Nil) - } - withSQLConf(SQLConf.GROUP_BY_ALIASES_ENABLED.key -> "false") { - Seq((1, "a", 0), (2, "a", 1), (1, "a", 2)).toDF("k1", "k2", "v") - .createOrReplaceTempView("t") - val errMsg = intercept[AnalysisException] { - sql("SELECT k1 AS key1, k2 AS key2, SUM(v) FROM t GROUP BY key1, key2") - } - assert(errMsg.getMessage.startsWith("cannot resolve")) - } - } } From f3a31af6d4f6e7b84f1474b2a6351746a405ae56 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Apr 2017 12:47:31 +0900 Subject: [PATCH 14/23] Apply review comments --- .../sql/catalyst/analysis/Analyzer.scala | 8 +--- .../apache/spark/sql/internal/SQLConf.scala | 4 +- .../analysis/AnalysisErrorSuite.scala | 2 +- .../sql-tests/inputs/group-by-ordinal.sql | 1 - .../resources/sql-tests/inputs/group-by.sql | 3 +- .../results/group-by-ordinal.sql.out | 30 ++++-------- .../sql-tests/results/group-by.sql.out | 46 +++++++------------ 7 files changed, 32 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4831aeffb352..c2547cd76f8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2378,13 +2378,9 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliasesEnabled && child.resolved && groups.exists(!_.resolved) => + if conf.groupByAliases && child.resolved && groups.exists(!_.resolved) => agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute => - aggs.find(ne => resolver(ne.name, u.name)).map { - case Alias(e, _) => e - case e => e - }.getOrElse(u) + case u: UnresolvedAttribute => aggs.find(ne => 
resolver(ne.name, u.name)).getOrElse(u) case e => e }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0486cee35093..b24419a41edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -421,7 +421,7 @@ object SQLConf { .booleanConf .createWithDefault(true) - val GROUP_BY_ALIASES_ENABLED = buildConf("spark.sql.groupByAliasesEnabled") + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") .doc("When true, aliases in a select list can be used in group by clauses. When false, " + "an analysis exception is thrown in the case.") .booleanConf @@ -1009,7 +1009,7 @@ class SQLConf extends Serializable with Logging { def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) - def groupByAliasesEnabled: Boolean = getConf(GROUP_BY_ALIASES_ENABLED) + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 3e88be321c2e..d2ebca5a83dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -164,7 +164,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "nested aggregate functions", - testRelation.groupBy()( + testRelation.groupBy('a)( AggregateExpression( Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)), Complete, diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 3ee351f7f255..6566338f3d4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -50,7 +50,6 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; select count(a), a from (select 1 as a) tmp group by 2 having a > 0; -- mixed cases: group-by ordinals and aliases -explain select a, a AS k, count(b) from data group by k, 1; select a, a AS k, count(b) from data group by k, 1; -- turn of group by ordinal diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 43e50e1b3e00..317660e0a674 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -37,7 +37,6 @@ FROM testData; SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Aliases in SELECT could be used in GROUP BY -EXPLAIN SELECT a AS k, COUNT(b) FROM testData GROUP BY k; SELECT a AS k, COUNT(b) FROM testData GROUP BY k; -- Test data. 
@@ -46,7 +45,7 @@ CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM V SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; -- turn off group by aliases -set spark.sql.groupByAliasesEnabled=false; +set spark.sql.groupByAliases=false; -- Check analysis exceptions SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 3fab1a261a33..9ecbe19078dd 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 21 +-- Number of queries: 20 -- !query 0 @@ -173,38 +173,26 @@ struct -- !query 17 -explain select a, a AS k, count(b) from data group by k, 1 --- !query 17 schema -struct --- !query 17 output -== Physical Plan == -*HashAggregate(keys=[a#x], functions=[count(1)]) -+- Exchange hashpartitioning(a#x, 200) - +- *HashAggregate(keys=[a#x], functions=[partial_count(1)]) - +- LocalTableScan [a#x] - - --- !query 18 select a, a AS k, count(b) from data group by k, 1 --- !query 18 schema +-- !query 17 schema struct --- !query 18 output +-- !query 17 output 1 1 2 2 2 2 3 3 2 --- !query 19 +-- !query 18 set spark.sql.groupByOrdinal=false --- !query 19 schema +-- !query 18 schema struct --- !query 19 output +-- !query 18 output spark.sql.groupByOrdinal false --- !query 20 +-- !query 19 select sum(b) from data group by -1 --- !query 20 schema +-- !query 19 schema struct --- !query 20 output +-- !query 19 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index a1fa2d072504..f2892ff73a88 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 21 +-- Number of queries: 20 -- !query 0 @@ -142,58 +142,46 @@ struct -- !query 15 -EXPLAIN SELECT a AS k, COUNT(b) FROM testData GROUP BY k --- !query 15 schema -struct --- !query 15 output -== Physical Plan == -*HashAggregate(keys=[a#x], functions=[count(b#x)]) -+- Exchange hashpartitioning(a#x, 200) - +- *HashAggregate(keys=[a#x], functions=[partial_count(b#x)]) - +- LocalTableScan [a#x, b#x] - - --- !query 16 SELECT a AS k, COUNT(b) FROM testData GROUP BY k --- !query 16 schema +-- !query 15 schema struct --- !query 16 output +-- !query 15 output 1 2 2 2 3 2 NULL 1 --- !query 17 +-- !query 16 CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES (1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) --- !query 17 schema +-- !query 16 schema struct<> --- !query 17 output +-- !query 16 output --- !query 18 +-- !query 17 SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a --- !query 18 schema +-- !query 17 schema struct<> --- !query 18 output +-- !query 17 output org.apache.spark.sql.AnalysisException expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; --- !query 19 -set spark.sql.groupByAliasesEnabled=false --- !query 19 schema +-- !query 18 +set spark.sql.groupByAliases=false +-- !query 18 schema struct --- !query 19 output -spark.sql.groupByAliasesEnabled false +-- !query 18 output +spark.sql.groupByAliases false --- !query 20 +-- !query 19 SELECT a AS k, COUNT(b) FROM testData GROUP BY k --- !query 20 schema +-- !query 19 schema struct<> --- !query 20 output +-- !query 19 output org.apache.spark.sql.AnalysisException cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 From 620341ae8d33f25b6186efa54d25733d2d313607 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Apr 2017 15:37:54 +0900 Subject: [PATCH 15/23] Throw an exception when aggregate functions exist in GROUP BY --- .../sql/catalyst/analysis/Analyzer.scala | 11 ++++++- .../resources/sql-tests/inputs/group-by.sql | 3 ++ .../sql-tests/results/group-by.sql.out | 31 ++++++++++++------- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c2547cd76f8a..5ca6db257074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2380,7 +2380,16 @@ class Analyzer( case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && groups.exists(!_.resolved) => agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute => aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + case u: UnresolvedAttribute => + val resolvedAgg = aggs.find(ne => resolver(ne.name, u.name)) + // Check if no aggregate function exists in GROUP BY + resolvedAgg.foreach { case Alias(expr, _) => + if (expr.isInstanceOf[AggregateExpression]) { + throw new AnalysisException( + s"Aggregate function `$expr` is not allowed in GROUP BY") + } + } + resolvedAgg.getOrElse(u) case e => e }) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 317660e0a674..40e3592a8824 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -39,6 +39,9 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS -- Aliases in SELECT could be used in GROUP BY SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +-- Aggregate functions cannot be used in GROUP BY +SELECT COUNT(b) AS k FROM testData GROUP BY k; + -- Test data. 
CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES (1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index f2892ff73a88..314cf71bd743 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 21 -- !query 0 @@ -153,35 +153,44 @@ NULL 1 -- !query 16 -CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES -(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +SELECT COUNT(b) AS k FROM testData GROUP BY k -- !query 16 schema struct<> -- !query 16 output - +org.apache.spark.sql.AnalysisException +Aggregate function `count(b#x)` is not allowed in GROUP BY; -- !query 17 -SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) -- !query 17 schema struct<> -- !query 17 output + + + +-- !query 18 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 18 schema +struct<> +-- !query 18 output org.apache.spark.sql.AnalysisException expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; --- !query 18 +-- !query 19 set spark.sql.groupByAliases=false --- !query 18 schema +-- !query 19 schema struct --- !query 18 output +-- !query 19 output spark.sql.groupByAliases false --- !query 19 +-- !query 20 SELECT a AS k, COUNT(b) FROM testData GROUP BY k --- !query 19 schema +-- !query 20 schema struct<> --- !query 19 output +-- !query 20 output org.apache.spark.sql.AnalysisException cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 From 658cf8321388e5d285cd64cb65f9861d620feca8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Apr 2017 19:28:01 +0900 Subject: [PATCH 16/23] Revert the location of ResolveAggAliasInGroupBy and add a test --- .../sql/catalyst/analysis/Analyzer.scala | 52 +++++++++---------- .../resources/sql-tests/inputs/group-by.sql | 1 + .../sql-tests/results/group-by.sql.out | 39 ++++++++------ 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5ca6db257074..92417b8b1a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -136,6 +136,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: @@ -162,8 +163,6 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), - Batch("ResolveAggAliasInGroupBy", Once, - ResolveAggAliasInGroupBy), Batch("ResolveTimeZone", Once, ResolveTimeZone), Batch("Subquery", Once, @@ -1000,6 +999,31 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT 
clauses. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute => + val resolvedAgg = aggs.find(ne => resolver(ne.name, u.name)) + // Check if no aggregate function exists in GROUP BY + resolvedAgg.foreach { + case Alias(e, _) if ResolveAggregateFunctions.containsAggregate(e) => + throw new AnalysisException( + s"Aggregate function `$e` is not allowed in GROUP BY") + case _ => + } + resolvedAgg.getOrElse(u) + case e => e + }) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -2371,30 +2395,6 @@ class Analyzer( } } - /** - * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. - */ - object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { - case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliases && child.resolved && groups.exists(!_.resolved) => - agg.copy(groupingExpressions = groups.map { - case u: UnresolvedAttribute => - val resolvedAgg = aggs.find(ne => resolver(ne.name, u.name)) - // Check if no aggregate function exists in GROUP BY - resolvedAgg.foreach { case Alias(expr, _) => - if (expr.isInstanceOf[AggregateExpression]) { - throw new AnalysisException( - s"Aggregate function `$expr` is not allowed in GROUP BY") - } - } - resolvedAgg.getOrElse(u) - case e => e - }) - } - } - /** * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local * time zone. 
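
In SQL terms, the behaviour this relocated rule is expected to produce, and which the
group-by.sql changes below record as golden tests, can be sketched against the testData
view used throughout the series (expected rows and error text are taken from the
group-by.sql.out expectations in this same patch):

-- An alias in the select list resolves to its underlying expression,
-- and stays visible to a HAVING clause over the aggregate:
SELECT a AS k, COUNT(b) FROM testData GROUP BY k;               -- 1,2  2,2  3,2  NULL,1
SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1;  -- 2,2  3,2

-- An alias that names an aggregate cannot be grouped on:
SELECT COUNT(b) AS k FROM testData GROUP BY k;
-- AnalysisException: Aggregate function `count(b#x)` is not allowed in GROUP BY
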
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 40e3592a8824..a7994f3beaff 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -38,6 +38,7 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS -- Aliases in SELECT could be used in GROUP BY SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; -- Aggregate functions cannot be used in GROUP BY SELECT COUNT(b) AS k FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 314cf71bd743..7274af347278 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 21 +-- Number of queries: 22 -- !query 0 @@ -153,44 +153,53 @@ NULL 1 -- !query 16 -SELECT COUNT(b) AS k FROM testData GROUP BY k +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1 -- !query 16 schema -struct<> +struct -- !query 16 output -org.apache.spark.sql.AnalysisException -Aggregate function `count(b#x)` is not allowed in GROUP BY; +2 2 +3 2 -- !query 17 -CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES -(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +SELECT COUNT(b) AS k FROM testData GROUP BY k -- !query 17 schema struct<> -- !query 17 output - +org.apache.spark.sql.AnalysisException +Aggregate function `count(b#x)` is not allowed in GROUP BY; -- !query 18 -SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) -- !query 18 schema struct<> -- !query 18 output + + + +-- !query 19 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 19 schema +struct<> +-- !query 19 output org.apache.spark.sql.AnalysisException expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. 
From dc6ca684ced9ed5abd130e7b722fdce966815cf0 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 21 Apr 2017 11:08:13 +0900
Subject: [PATCH 17/23] Remove analysis checks

---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 11 +----------
 .../test/resources/sql-tests/results/group-by.sql.out |  2 +-
 2 files changed, 2 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 92417b8b1a9f..2dc3a97139e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1009,16 +1009,7 @@ class Analyzer(
         if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
           groups.exists(!_.resolved) =>
         agg.copy(groupingExpressions = groups.map {
-          case u: UnresolvedAttribute =>
-            val resolvedAgg = aggs.find(ne => resolver(ne.name, u.name))
-            // Check if no aggregate function exists in GROUP BY
-            resolvedAgg.foreach {
-              case Alias(e, _) if ResolveAggregateFunctions.containsAggregate(e) =>
-                throw new AnalysisException(
-                  s"Aggregate function `$e` is not allowed in GROUP BY")
-              case _ =>
-            }
-            resolvedAgg.getOrElse(u)
+          case u: UnresolvedAttribute => aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
           case e => e
         })
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 7274af347278..6bf9dff883c1 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -167,7 +167,7 @@ SELECT COUNT(b) AS k FROM testData GROUP BY k
 struct<>
 -- !query 17 output
 org.apache.spark.sql.AnalysisException
-Aggregate function `count(b#x)` is not allowed in GROUP BY;
+aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`);
 
 
 -- !query 18
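With the rule-local check removed, grouping by an alias of an aggregate now fails in the analyzer's generic checks instead, which is why the golden file switches to the lower-case message above. From a test's point of view the expectation would look roughly like this (a sketch in SQLQuerySuite style; the message text is the one recorded in the golden file):

    // The error now originates from the generic analysis checks rather than
    // from inside ResolveAggAliasInGroupBy itself.
    val e = intercept[AnalysisException] {
      sql("SELECT COUNT(b) AS k FROM testData GROUP BY k").collect()
    }
    assert(e.getMessage.contains("aggregate functions are not allowed in GROUP BY"))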
From 7b32f46b1dd83007f066ebcc4dc92a48da6ca89a Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 21 Apr 2017 11:16:08 +0900
Subject: [PATCH 18/23] Fix some issues

---
 .../sql/catalyst/analysis/Analyzer.scala | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2dc3a97139e3..0df131f72b29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -163,8 +163,6 @@ class Analyzer(
       HandleNullInputsForUDF),
     Batch("FixNullability", Once,
       FixNullability),
-    Batch("ResolveTimeZone", Once,
-      ResolveTimeZone),
     Batch("Subquery", Once,
      UpdateOuterReferences),
     Batch("Cleanup", fixedPoint,
@@ -2385,23 +2383,6 @@ class Analyzer(
       }
     }
   }
-
-  /**
-   * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local
-   * time zone.
-   */
-  object ResolveTimeZone extends Rule[LogicalPlan] {
-
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
-      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
-        e.withTimeZone(conf.sessionLocalTimeZone)
-      // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
-      // the types between the value expression and list query expression of IN expression.
-      // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
-      // information for time zone aware expressions.
-      case e: ListQuery => e.withNewPlan(apply(e.plan))
-    }
-  }
 }

From 1340862bd9e21c4e3102a9e3fc8441fa68d8383f Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 26 Apr 2017 23:30:05 +0900
Subject: [PATCH 19/23] Move some rules into postHocResolutionRules

---
 .../spark/sql/catalyst/analysis/Analyzer.scala   | 14 ++++++++------
 .../sql/internal/BaseSessionStateBuilder.scala   |  9 +++++----
 .../spark/sql/hive/HiveSessionStateBuilder.scala | 15 ++++++++-------
 3 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0df131f72b29..5745f7ba4919 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -112,7 +112,13 @@ class Analyzer(
    * in an individual batch. This batch is to run right after the normal resolution batch and
    * execute its rules in one pass.
    */
-  val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+  def postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+    ResolveOrdinalInOrderByAndGroupBy ::
+    ResolveAggAliasInGroupBy ::
+    ResolveMissingReferences ::
+    ResolveSubquery ::
+    ResolveAggregateFunctions ::
+    Nil
 
   lazy val batches: Seq[Batch] = Seq(
     Batch("Hints", fixedPoint,
@@ -135,9 +141,6 @@ class Analyzer(
       ResolveUpCast ::
       ResolveGroupingAnalytics ::
      ResolvePivot ::
-      ResolveOrdinalInOrderByAndGroupBy ::
-      ResolveAggAliasInGroupBy ::
-      ResolveMissingReferences ::
       ExtractGenerator ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -148,13 +151,12 @@ class Analyzer(
       ResolveNaturalAndUsingJoin ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
-      ResolveAggregateFunctions ::
       TimeWindowing ::
       ResolveInlineTables(conf) ::
       ResolveTimeZone(conf) ::
       TypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
-    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
+    Batch("Post-Hoc Resolution", fixedPoint, postHocResolutionRules: _*),
     Batch("View", Once,
       AliasViewChild(conf)),
     Batch("Nondeterministic", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 2a801d87b12e..f8587ea36651 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -161,10 +161,11 @@ abstract class BaseSessionStateBuilder(
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-      PreprocessTableCreation(session) +:
-        PreprocessTableInsertion(conf) +:
-        DataSourceAnalysis(conf) +:
-        customPostHocResolutionRules
+      super.postHocResolutionRules ++ (
+        PreprocessTableCreation(session) +:
+        PreprocessTableInsertion(conf) +:
+        DataSourceAnalysis(conf) +:
+        customPostHocResolutionRules)
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index e16c9e46b772..770cb306e630 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -74,13 +74,14 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       customResolutionRules
 
    override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-      new DetermineTableStats(session) +:
-        RelationConversions(conf, catalog) +:
-        PreprocessTableCreation(session) +:
-        PreprocessTableInsertion(conf) +:
-        DataSourceAnalysis(conf) +:
-        HiveAnalysis +:
-        customPostHocResolutionRules
+      super.postHocResolutionRules ++ (
+        new DetermineTableStats(session) +:
+        RelationConversions(conf, catalog) +:
+        PreprocessTableCreation(session) +:
+        PreprocessTableInsertion(conf) +:
+        DataSourceAnalysis(conf) +:
+        HiveAnalysis +:
+        customPostHocResolutionRules)
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
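The mechanics of this restructuring are plain template-method composition: the base analyzer now contributes rules to postHocResolutionRules and each session builder appends its own through super (the next patch reverts this). A minimal standalone sketch of the chaining, with placeholder types rather than Catalyst's; Rule, BaseBuilder, and HiveBuilder here are hypothetical:

    trait Rule { def name: String }
    final case class Named(name: String) extends Rule

    class BaseBuilder {
      // Rules contributed by the base class.
      def postHocResolutionRules: Seq[Rule] =
        Seq(Named("ResolveOrdinalInOrderByAndGroupBy"), Named("ResolveAggAliasInGroupBy"))
    }

    class HiveBuilder extends BaseBuilder {
      // Session-specific rules are appended after the inherited ones,
      // mirroring super.postHocResolutionRules ++ (...) above.
      override def postHocResolutionRules: Seq[Rule] =
        super.postHocResolutionRules ++
          Seq(Named("DetermineTableStats"), Named("HiveAnalysis"))
    }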
From 65f6e7cd196c50e01f7f0542eee153ba214b28b4 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 27 Apr 2017 08:00:23 +0900
Subject: [PATCH 20/23] Revert "Move some rules into postHocResolutionRules"

This reverts commit 1340862bd9e21c4e3102a9e3fc8441fa68d8383f.

---
 .../spark/sql/catalyst/analysis/Analyzer.scala   | 14 ++++++--------
 .../sql/internal/BaseSessionStateBuilder.scala   |  9 ++++-----
 .../spark/sql/hive/HiveSessionStateBuilder.scala | 15 +++++++--------
 3 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 5745f7ba4919..0df131f72b29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -112,13 +112,7 @@ class Analyzer(
    * in an individual batch. This batch is to run right after the normal resolution batch and
    * execute its rules in one pass.
    */
-  def postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-    ResolveOrdinalInOrderByAndGroupBy ::
-    ResolveAggAliasInGroupBy ::
-    ResolveMissingReferences ::
-    ResolveSubquery ::
-    ResolveAggregateFunctions ::
-    Nil
+  val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
 
   lazy val batches: Seq[Batch] = Seq(
     Batch("Hints", fixedPoint,
@@ -141,6 +135,9 @@ class Analyzer(
       ResolveUpCast ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveOrdinalInOrderByAndGroupBy ::
+      ResolveAggAliasInGroupBy ::
+      ResolveMissingReferences ::
       ExtractGenerator ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -151,12 +148,13 @@ class Analyzer(
       ResolveNaturalAndUsingJoin ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
+      ResolveAggregateFunctions ::
       TimeWindowing ::
       ResolveInlineTables(conf) ::
       ResolveTimeZone(conf) ::
       TypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
-    Batch("Post-Hoc Resolution", fixedPoint, postHocResolutionRules: _*),
+    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
     Batch("View", Once,
       AliasViewChild(conf)),
     Batch("Nondeterministic", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index f8587ea36651..2a801d87b12e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -161,11 +161,10 @@ abstract class BaseSessionStateBuilder(
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-      super.postHocResolutionRules ++ (
-        PreprocessTableCreation(session) +:
-        PreprocessTableInsertion(conf) +:
-        DataSourceAnalysis(conf) +:
-        customPostHocResolutionRules)
+      PreprocessTableCreation(session) +:
+        PreprocessTableInsertion(conf) +:
+        DataSourceAnalysis(conf) +:
+        customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 770cb306e630..e16c9e46b772 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -74,14 +74,13 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-      super.postHocResolutionRules ++ (
-        new DetermineTableStats(session) +:
-        RelationConversions(conf, catalog) +:
-        PreprocessTableCreation(session) +:
-        PreprocessTableInsertion(conf) +:
-        DataSourceAnalysis(conf) +:
-        HiveAnalysis +:
-        customPostHocResolutionRules)
+      new DetermineTableStats(session) +:
+        RelationConversions(conf, catalog) +:
+        PreprocessTableCreation(session) +:
+        PreprocessTableInsertion(conf) +:
+        DataSourceAnalysis(conf) +:
+        HiveAnalysis +:
+        customPostHocResolutionRules
 
     override val extendedCheckRules: Seq[LogicalPlan => Unit] =
       PreWriteCheck +:

From 86402b08cd00a952e206268fc2ac50a5e427a6b1 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 27 Apr 2017 16:25:51 +0900
Subject: [PATCH 21/23] Apply review comments

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0df131f72b29..2e83c9c52a4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -999,13 +999,14 @@ class Analyzer(
   /**
    * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
+   * This rule is expected to run after [[ResolveReferences]] is applied.
    */
   object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
 
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case agg @ Aggregate(groups, aggs, child)
         if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
-          groups.exists(!_.resolved) =>
+          groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
         agg.copy(groupingExpressions = groups.map {
           case u: UnresolvedAttribute => aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
           case e => e

From 0ae48d825678adda2d110be816f12bb4830aaf23 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 27 Apr 2017 17:09:34 +0900
Subject: [PATCH 22/23] Put strict checker

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2e83c9c52a4f..ca42d424d378 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1007,8 +1007,12 @@ class Analyzer(
       case agg @ Aggregate(groups, aggs, child)
         if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
           groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
+        // This is a strict check; only rewrite names that are not resolvable from the child output
+        def checkIfChildOutputHasNo(attrName: String): Boolean =
+          !child.output.exists(a => resolver(a.name, attrName))
         agg.copy(groupingExpressions = groups.map {
-          case u: UnresolvedAttribute => aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
+          case u: UnresolvedAttribute if checkIfChildOutputHasNo(u.name) =>
+            aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
           case e => e
         })
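The guard added here makes the precedence from the earlier golden-file change explicit: an alias is consulted only when the grouping name cannot be resolved from the child at all. A rough standalone model of the predicate (childOutput is a hypothetical stand-in for child.output):

    val resolver = (a: String, b: String) => a.equalsIgnoreCase(b)
    val childOutput = Seq("k", "a", "v") // column names of the child plan

    // Mirrors checkIfChildOutputHasNo: true only when no child column matches.
    def checkIfChildOutputHasNo(attrName: String): Boolean =
      !childOutput.exists(col => resolver(col, attrName))

    assert(!checkIfChildOutputHasNo("a")) // real column: the alias k AS a is ignored
    assert(checkIfChildOutputHasNo("x"))  // no such column: an alias may be substituted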
From d3071fa48d329dd39756d42f40207a212cf7b712 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 28 Apr 2017 01:31:54 +0900
Subject: [PATCH 23/23] Rename a function

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ca42d424d378..72e7d5dd3638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1008,10 +1008,10 @@ class Analyzer(
         if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
           groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
         // This is a strict check; only rewrite names that are not resolvable from the child output
-        def checkIfChildOutputHasNo(attrName: String): Boolean =
+        def notResolvableByChild(attrName: String): Boolean =
           !child.output.exists(a => resolver(a.name, attrName))
         agg.copy(groupingExpressions = groups.map {
-          case u: UnresolvedAttribute if checkIfChildOutputHasNo(u.name) =>
+          case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
            aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
           case e => e
         })
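Taken together, the series yields the following user-visible behavior, sketched in SQLQuerySuite style against a testData view with columns a and b as in the golden files:

    // Default: spark.sql.groupByAliases=true, so SELECT aliases work as grouping keys.
    spark.conf.set("spark.sql.groupByAliases", "true")
    sql("SELECT a AS k, COUNT(b) FROM testData GROUP BY k").collect() // succeeds

    // Disabled: resolution falls back to the old behavior and analysis fails with
    // "cannot resolve '`k`' given input columns: [a, b]".
    spark.conf.set("spark.sql.groupByAliases", "false")
    intercept[AnalysisException] {
      sql("SELECT a AS k, COUNT(b) FROM testData GROUP BY k").collect()
    }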