Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support no distinct aggregate sum/min/max in single_distinct_to_group_by rule #8266

Merged
merged 7 commits into from
Nov 26, 2023

Conversation

haohuaijin
Copy link
Contributor

@haohuaijin haohuaijin commented Nov 19, 2023

Which issue does this PR close?

Closes #8123

Rationale for this change

Generate test data

./benchmarks/bench.sh data

Compare

in this pr(release mode)

❯ SELECT "RegionID", SUM("AdvEngineID"), COUNT(DISTINCT "UserID") FROM '../benchmarks/data/hits.parquet' GROUP BY "RegionID" order by "RegionID" limit 10;
+----------+--------------------------------------------------+--------------------------------------------------------+
| RegionID | SUM(../benchmarks/data/hits.parquet.AdvEngineID) | COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID) |
+----------+--------------------------------------------------+--------------------------------------------------------+
| 0        | 0                                                | 8                                                      |
| 1        | 147946                                           | 239380                                                 |
| 2        | 441662                                           | 1081016                                                |
| 3        | 39724                                            | 131195                                                 |
| 4        | 34557                                            | 79500                                                  |
| 5        | 13502                                            | 40914                                                  |
| 6        | 24338                                            | 55768                                                  |
| 7        | 28417                                            | 64989                                                  |
| 8        | 34483                                            | 65472                                                  |
| 9        | 38047                                            | 91576                                                  |
+----------+--------------------------------------------------+--------------------------------------------------------+
10 rows in set. Query took 0.935 seconds.

❯ explain SELECT "RegionID", SUM("AdvEngineID"), COUNT(DISTINCT "UserID") FROM '../benchmarks/data/hits.parquet' GROUP BY "RegionID" order by "RegionID" limit 10;
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type     | plan                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan  | Limit: skip=0, fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                                                                     |
|               |   Sort: ../benchmarks/data/hits.parquet.RegionID ASC NULLS LAST, fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                   |
|               |     Projection: ../benchmarks/data/hits.parquet.RegionID, SUM(alias2) AS SUM(../benchmarks/data/hits.parquet.AdvEngineID), COUNT(alias1) AS COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID)                                                                                                                                                                                                                                                                                          |
|               |       Aggregate: groupBy=[[../benchmarks/data/hits.parquet.RegionID]], aggr=[[SUM(alias2), COUNT(alias1)]]                                                                                                                                                                                                                                                                                                                                                                                  |
|               |         Aggregate: groupBy=[[../benchmarks/data/hits.parquet.RegionID, ../benchmarks/data/hits.parquet.UserID AS alias1]], aggr=[[SUM(CAST(../benchmarks/data/hits.parquet.AdvEngineID AS Int64)) AS alias2]]                                                                                                                                                                                                                                                                               |
|               |           TableScan: ../benchmarks/data/hits.parquet projection=[RegionID, UserID, AdvEngineID]                                                                                                                                                                                                                                                                                                                                                                                             |
| physical_plan | GlobalLimitExec: skip=0, fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                                                           |
|               |   SortPreservingMergeExec: [RegionID@0 ASC NULLS LAST], fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                            |
|               |     SortExec: TopK(fetch=10), expr=[RegionID@0 ASC NULLS LAST]                                                                                                                                                                                                                                                                                                                                                                                                                              |
|               |       ProjectionExec: expr=[RegionID@0 as RegionID, SUM(alias2)@1 as SUM(../benchmarks/data/hits.parquet.AdvEngineID), COUNT(alias1)@2 as COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID)]                                                                                                                                                                                                                                                                                           |
|               |         AggregateExec: mode=FinalPartitioned, gby=[RegionID@0 as RegionID], aggr=[SUM(alias2), COUNT(alias1)]                                                                                                                                                                                                                                                                                                                                                                               |
|               |           CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                                                                                                                                                                                                                                                                                                       |
|               |             RepartitionExec: partitioning=Hash([RegionID@0], 24), input_partitions=24                                                                                                                                                                                                                                                                                                                                                                                                       |
|               |               AggregateExec: mode=Partial, gby=[RegionID@0 as RegionID], aggr=[SUM(alias2), COUNT(alias1)]                                                                                                                                                                                                                                                                                                                                                                                  |
|               |                 AggregateExec: mode=FinalPartitioned, gby=[RegionID@0 as RegionID, alias1@1 as alias1], aggr=[alias2]                                                                                                                                                                                                                                                                                                                                                                       |
|               |                   CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                                                                                                                                                                                                                                                                                               |
|               |                     RepartitionExec: partitioning=Hash([RegionID@0, alias1@1], 24), input_partitions=24                                                                                                                                                                                                                                                                                                                                                                                     |
|               |                       AggregateExec: mode=Partial, gby=[RegionID@0 as RegionID, UserID@1 as alias1], aggr=[alias2]                                                                                                                                                                                                                                                                                                                                                                          |
|               |                         ParquetExec: file_groups={24 groups: [[home/hhj/datafusion/benchmarks/data/hits.parquet:0..615832352], [home/hhj/datafusion/benchmarks/data/hits.parquet:615832352..1231664704], [home/hhj/datafusion/benchmarks/data/hits.parquet:1231664704..1847497056], [home/hhj/datafusion/benchmarks/data/hits.parquet:1847497056..2463329408], [home/hhj/datafusion/benchmarks/data/hits.parquet:2463329408..3079161760], ...]}, projection=[RegionID, UserID, AdvEngineID] |
|               |                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
2 rows in set. Query took 0.043 seconds.

in main 393e48f (release mode)

❯ SELECT "RegionID", SUM("AdvEngineID"), COUNT(DISTINCT "UserID") FROM '../benchmarks/data/hits.parquet' GROUP BY "RegionID" order by "RegionID" limit 10;
+----------+--------------------------------------------------+--------------------------------------------------------+
| RegionID | SUM(../benchmarks/data/hits.parquet.AdvEngineID) | COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID) |
+----------+--------------------------------------------------+--------------------------------------------------------+
| 0        | 0                                                | 8                                                      |
| 1        | 147946                                           | 239380                                                 |
| 2        | 441662                                           | 1081016                                                |
| 3        | 39724                                            | 131195                                                 |
| 4        | 34557                                            | 79500                                                  |
| 5        | 13502                                            | 40914                                                  |
| 6        | 24338                                            | 55768                                                  |
| 7        | 28417                                            | 64989                                                  |
| 8        | 34483                                            | 65472                                                  |
| 9        | 38047                                            | 91576                                                  |
+----------+--------------------------------------------------+--------------------------------------------------------+
10 rows in set. Query took 1.349 seconds.

❯ explain SELECT "RegionID", SUM("AdvEngineID"), COUNT(DISTINCT "UserID") FROM '../benchmarks/data/hits.parquet' GROUP BY "RegionID" order by "RegionID" limit 10;
+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type     | plan                                                                                                                                                                                                                                                                                                                                                                                                                                                                              |
+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan  | Limit: skip=0, fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                                                           |
|               |   Sort: ../benchmarks/data/hits.parquet.RegionID ASC NULLS LAST, fetch=10                                                                                                                                                                                                                                                                                                                                                                                                         |
|               |     Aggregate: groupBy=[[../benchmarks/data/hits.parquet.RegionID]], aggr=[[SUM(CAST(../benchmarks/data/hits.parquet.AdvEngineID AS Int64)), COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID)]]                                                                                                                                                                                                                                                                             |
|               |       TableScan: ../benchmarks/data/hits.parquet projection=[RegionID, UserID, AdvEngineID]                                                                                                                                                                                                                                                                                                                                                                                       |
| physical_plan | GlobalLimitExec: skip=0, fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                                                 |
|               |   SortPreservingMergeExec: [RegionID@0 ASC NULLS LAST], fetch=10                                                                                                                                                                                                                                                                                                                                                                                                                  |
|               |     SortExec: TopK(fetch=10), expr=[RegionID@0 ASC NULLS LAST]                                                                                                                                                                                                                                                                                                                                                                                                                    |
|               |       AggregateExec: mode=FinalPartitioned, gby=[RegionID@0 as RegionID], aggr=[SUM(../benchmarks/data/hits.parquet.AdvEngineID), COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID)]                                                                                                                                                                                                                                                                                         |
|               |         CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                                                                                                                                                                                                                                                                                               |
|               |           RepartitionExec: partitioning=Hash([RegionID@0], 24), input_partitions=24                                                                                                                                                                                                                                                                                                                                                                                               |
|               |             AggregateExec: mode=Partial, gby=[RegionID@0 as RegionID], aggr=[SUM(../benchmarks/data/hits.parquet.AdvEngineID), COUNT(DISTINCT ../benchmarks/data/hits.parquet.UserID)]                                                                                                                                                                                                                                                                                            |
|               |               ParquetExec: file_groups={24 groups: [[home/hhj/datafusion/benchmarks/data/hits.parquet:0..615832352], [home/hhj/datafusion/benchmarks/data/hits.parquet:615832352..1231664704], [home/hhj/datafusion/benchmarks/data/hits.parquet:1231664704..1847497056], [home/hhj/datafusion/benchmarks/data/hits.parquet:1847497056..2463329408], [home/hhj/datafusion/benchmarks/data/hits.parquet:2463329408..3079161760], ...]}, projection=[RegionID, UserID, AdvEngineID] |
|               |                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   |
+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
2 rows in set. Query took 0.040 seconds.

What changes are included in this PR?

add no-distinct sum/min/max aggregate support in single_distinct_to_group_by rule

Are these changes tested?

yes, add some tests

Are there any user-facing changes?

@github-actions github-actions bot added optimizer Optimizer rules sqllogictest SQL Logic Tests (.slt) labels Nov 19, 2023
@alamb
Copy link
Contributor

alamb commented Nov 21, 2023

Thank you @haohuaijin -- I plan to review this PR tomorrow

Copy link
Contributor

@alamb alamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @haohuaijin -- this looks really neat.

I had some small improvement suggestions that are not required as well as a a few additional tests to add

I am also running a full ClickBench benchmark run to see how this branch compares. I'll post the results to this PR

Please let me know if you are willing to make changes to this PR or if I should merge it in as is.

Thanks again, this is really neat

datafusion/optimizer/src/single_distinct_to_groupby.rs Outdated Show resolved Hide resolved
datafusion/optimizer/src/single_distinct_to_groupby.rs Outdated Show resolved Hide resolved
for e in args {
fields_set.insert(e.canonical_name());
}
} else if !matches!(fun, Sum | Min | Max) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a comment here explaining why this is checking for only Sum, Min and Max

I think it is because these functions have the property that they can be used to combine themselves -- so like SUM(SUM(x)) is the same as SUM(x)

I think this transformation could be made more general by using a different combination. For example we could support COUNT by using SUM(COUNT(x))

SELECT a, COUNT(DINSTINCT b), COUNT(c)
FROM t
GROUP BY a
SELECT a, COUNT(alias1), SUM(alias2) -- <-- This is SUM, not COUNT
FROM (
  SELECT a, b as alias1, COUNT(c) as alias2
  FROM t
  GROUP BY a, b
)
GROUP BY a

We can do this as a follow on PR ( I can file a ticket)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also do crazier stuff for AVG like

SELECT a, COUNT(DINSTINCT b), AVG(c)
FROM t
GROUP BY a
SELECT a, COUNT(alias1), SUM(alias2) / SUM(alias3) -- <-- This is combining partial sum / counts to compute AVG
FROM (
  SELECT a, b as alias1, SUM(c) as alias2, COUNT(c) as alias3,
  FROM t
  GROUP BY a, b
)
GROUP BY a

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rewrite Count requires us to implement Sum0. you can refer to the discussion in this link for details.

vec![
sum(col("c")),
count_distinct(col("b")),
Expr::AggregateFunction(expr::AggregateFunction::new(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@haohuaijin haohuaijin Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because we only have count_distinct interface. Should we also add sum_distinct, max_distinct (maybe no sense, because the result of max_distinct is equal to max) and more?

@haohuaijin
Copy link
Contributor Author

haohuaijin commented Nov 22, 2023

Thanks for review @alamb , those suggestions are very helpful. I will apply those suggestions tonight or tomorrow.

@haohuaijin haohuaijin requested a review from alamb November 23, 2023 05:06
@alamb
Copy link
Contributor

alamb commented Nov 26, 2023

Thanks again @haohuaijin

@alamb alamb merged commit f29bcf3 into apache:main Nov 26, 2023
22 checks passed
@haohuaijin haohuaijin deleted the distinct branch November 26, 2023 12:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
optimizer Optimizer rules sqllogictest SQL Logic Tests (.slt)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support no distinct count/max/min/sum aggregate in single_distinct_to_group_by rule
2 participants