Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1098,9 +1098,16 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
*/
object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there an existing unit test suite for this? might be good to add a test case there too.

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(grouping, _, _) =>
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
a.copy(groupingExpressions = newGrouping)
if (newGrouping.nonEmpty) {
a.copy(groupingExpressions = newGrouping)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
// instead replace this by single, easy to hash/sort, literal expression.
a.copy(groupingExpressions = Seq(Literal(0, IntegerType)))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false)
val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
val analyzer = new Analyzer(catalog, conf)

Expand All @@ -49,6 +49,14 @@ class AggregateOptimizeSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("do not remove all grouping expressions if they are all literals") {
val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

comparePlans(optimized, correctAnswer)
}

test("Remove aliased literals") {
val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
-- Temporary data.
create temporary view myview as values 128, 256 as v(int_col);

-- group by should produce all input rows,
select int_col, count(*) from myview group by int_col;

-- group by should produce a single row.
select 'foo', count(*) from myview group by 1;

-- group-by should not produce any rows (whole stage code generation).
select 'foo' from myview where int_col == 0 group by 1;

-- group-by should not produce any rows (hash aggregate).
select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1;

-- group-by should not produce any rows (sort aggregate).
select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1;
51 changes: 51 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 6


-- !query 0
create temporary view myview as values 128, 256 as v(int_col)
-- !query 0 schema
struct<>
-- !query 0 output



-- !query 1
select int_col, count(*) from myview group by int_col
-- !query 1 schema
struct<int_col:int,count(1):bigint>
-- !query 1 output
128 1
256 1


-- !query 2
select 'foo', count(*) from myview group by 1
-- !query 2 schema
struct<foo:string,count(1):bigint>
-- !query 2 output
foo 2


-- !query 3
select 'foo' from myview where int_col == 0 group by 1
-- !query 3 schema
struct<foo:string>
-- !query 3 output



-- !query 4
select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1
-- !query 4 schema
struct<foo:string,approx_count_distinct(int_col):bigint>
-- !query 4 output



-- !query 5
select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1
-- !query 5 schema
struct<foo:string,max(struct(int_col)):struct<int_col:int>>
-- !query 5 output