Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: grouping sets, rollup and cube rendering issue #1019

7 changes: 6 additions & 1 deletion sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,12 @@ def visit_label(self, *args, within_group_by=False, **kwargs):
# Flag set in the group_by_clause method. Works around missing
# equivalent to supports_simple_order_by_label for group by.
if within_group_by:
kwargs["render_label_as_label"] = args[0]
if all(
keyword not in str(args[0])
for keyword in ("GROUPING SETS", "ROLLUP", "CUBE")
):
kwargs["render_label_as_label"] = args[0]

kiraksi marked this conversation as resolved.
Show resolved Hide resolved
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)

def group_by_clause(self, select, **kw):
Expand Down
86 changes: 86 additions & 0 deletions tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
sqlalchemy_2_0_or_higher,
sqlalchemy_before_2_0,
)
from sqlalchemy.sql.functions import rollup, cube, grouping_sets
from sqlalchemy import func


def test_constraints_are_ignored(faux_conn, metadata):
Expand Down Expand Up @@ -279,3 +281,87 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata)
)
found_outer_sql = q.compile(faux_conn).string
assert found_outer_sql == expected_outer_sql


def test_grouping_sets(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
grouping_sets(table.c.foo, table.c.bar)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


def test_rollup(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
rollup(table.c.foo, table.c.bar)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


def test_cube(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
cube(table.c.foo, table.c.bar)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


def test_multiple_grouping_sets(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql