Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: grouping sets, rollup and cube rendering issue #1019

11 changes: 8 additions & 3 deletions sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,14 @@ def visit_label(self, *args, within_group_by=False, **kwargs):
# Flag set in the group_by_clause method. Works around missing
# equivalent to supports_simple_order_by_label for group by.
if within_group_by:
kwargs["render_label_as_label"] = args[0]
column_label = args[0]
sql_keywords = {"GROUPING SETS", "ROLLUP", "CUBE"}
for keyword in sql_keywords:
if keyword in str(column_label):
break
else: # for/else always happens unless break gets called
kwargs["render_label_as_label"] = column_label

return super(BigQueryCompiler, self).visit_label(*args, **kwargs)

def group_by_clause(self, select, **kw):
Expand Down Expand Up @@ -395,8 +402,6 @@ def visit_not_in_op_binary(self, binary, operator, **kw):
+ ")"
)

visit_notin_op_binary = visit_not_in_op_binary # before 1.4

############################################################################

############################################################################
Expand Down
85 changes: 85 additions & 0 deletions tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
sqlalchemy_2_0_or_higher,
sqlalchemy_before_2_0,
)
from sqlalchemy.sql.functions import rollup, cube, grouping_sets


def test_constraints_are_ignored(faux_conn, metadata):
Expand Down Expand Up @@ -279,3 +280,87 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata)
)
found_outer_sql = q.compile(faux_conn).string
assert found_outer_sql == expected_outer_sql


def test_grouping_sets(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
grouping_sets(table.c.foo, table.c.bar)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


def test_rollup(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
rollup(table.c.foo, table.c.bar)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


def test_cube(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
cube(table.c.foo, table.c.bar)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql


def test_multiple_grouping_sets(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo)
)

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)"
)
found_sql = q.compile(faux_conn).string
assert found_sql == expected_sql