Skip to content

Commit 0a3e172

Browse files
authored
refactor: add agg_ops.StdOp, VarOp and PopVarOp for the sqlglot compiler (#2224)
1 parent bfbb2f0 commit 0a3e172

File tree

10 files changed

+220
-2
lines changed

10 files changed

+220
-2
lines changed

bigframes/core/compile/sqlglot/aggregations/op_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ def arg_checker(*args, **kwargs):
5252
def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
5353
key = op if isinstance(op, type) else type(op)
5454
if str(key) not in self._registered_ops:
55-
raise ValueError(f"{key} is already not registered")
55+
raise ValueError(f"{key} is not registered")
5656
return self._registered_ops[str(key)]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,20 @@ def _(
239239
return apply_window_if_present(sge.func("MIN", column.expr), window)
240240

241241

242+
@UNARY_OP_REGISTRATION.register(agg_ops.PopVarOp)
243+
def _(
244+
op: agg_ops.PopVarOp,
245+
column: typed_expr.TypedExpr,
246+
window: typing.Optional[window_spec.WindowSpec] = None,
247+
) -> sge.Expression:
248+
expr = column.expr
249+
if column.dtype == dtypes.BOOL_DTYPE:
250+
expr = sge.Cast(this=expr, to="INT64")
251+
252+
expr = sge.func("VAR_POP", expr)
253+
return apply_window_if_present(expr, window)
254+
255+
242256
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
243257
def _(
244258
op: agg_ops.QuantileOp,
@@ -278,6 +292,22 @@ def _(
278292
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
279293

280294

295+
@UNARY_OP_REGISTRATION.register(agg_ops.StdOp)
296+
def _(
297+
op: agg_ops.StdOp,
298+
column: typed_expr.TypedExpr,
299+
window: typing.Optional[window_spec.WindowSpec] = None,
300+
) -> sge.Expression:
301+
expr = column.expr
302+
if column.dtype == dtypes.BOOL_DTYPE:
303+
expr = sge.Cast(this=expr, to="INT64")
304+
305+
expr = sge.func("STDDEV", expr)
306+
if op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE:
307+
expr = sge.Cast(this=sge.func("FLOOR", expr), to="INT64")
308+
return apply_window_if_present(expr, window)
309+
310+
281311
@UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp)
282312
def _(
283313
op: agg_ops.ShiftOp,
@@ -331,3 +361,17 @@ def _(
331361
expression=shifted,
332362
unit=sge.Identifier(this="MICROSECOND"),
333363
)
364+
365+
366+
@UNARY_OP_REGISTRATION.register(agg_ops.VarOp)
367+
def _(
368+
op: agg_ops.VarOp,
369+
column: typed_expr.TypedExpr,
370+
window: typing.Optional[window_spec.WindowSpec] = None,
371+
) -> sge.Expression:
372+
expr = column.expr
373+
if column.dtype == dtypes.BOOL_DTYPE:
374+
expr = sge.Cast(this=expr, to="INT64")
375+
376+
expr = sge.func("VAR_SAMP", expr)
377+
return apply_window_if_present(expr, window)

tests/system/small/engines/test_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_engines_unary_aggregates(
111111
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
112112

113113

114-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
114+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
115115
@pytest.mark.parametrize(
116116
"op",
117117
[agg_ops.std_op, agg_ops.var_op, agg_ops.PopVarOp()],
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
VAR_POP(`bfcol_1`) AS `bfcol_4`,
9+
VAR_POP(CAST(`bfcol_0` AS INT64)) AS `bfcol_5`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_4` AS `int64_col`,
14+
`bfcol_5` AS `bool_col`
15+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE VAR_POP(`bfcol_0`) OVER () END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`duration_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_1` AS `bfcol_6`,
11+
`bfcol_0` AS `bfcol_7`,
12+
`bfcol_2` AS `bfcol_8`
13+
FROM `bfcte_0`
14+
), `bfcte_2` AS (
15+
SELECT
16+
STDDEV(`bfcol_6`) AS `bfcol_12`,
17+
STDDEV(CAST(`bfcol_7` AS INT64)) AS `bfcol_13`,
18+
CAST(FLOOR(STDDEV(`bfcol_8`)) AS INT64) AS `bfcol_14`,
19+
CAST(FLOOR(STDDEV(`bfcol_6`)) AS INT64) AS `bfcol_15`
20+
FROM `bfcte_1`
21+
)
22+
SELECT
23+
`bfcol_12` AS `int64_col`,
24+
`bfcol_13` AS `bool_col`,
25+
`bfcol_14` AS `duration_col`,
26+
`bfcol_15` AS `int64_col_w_floor`
27+
FROM `bfcte_2`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE STDDEV(`bfcol_0`) OVER () END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
VARIANCE(`bfcol_1`) AS `bfcol_4`,
9+
VARIANCE(CAST(`bfcol_0` AS INT64)) AS `bfcol_5`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_4` AS `int64_col`,
14+
`bfcol_5` AS `bool_col`
15+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE VARIANCE(`bfcol_0`) OVER () END AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `agg_int64`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,28 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
370370
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
371371

372372

373+
def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
374+
col_names = ["int64_col", "bool_col"]
375+
bf_df = scalar_types_df[col_names]
376+
377+
agg_ops_map = {
378+
"int64_col": agg_ops.PopVarOp().as_expr("int64_col"),
379+
"bool_col": agg_ops.PopVarOp().as_expr("bool_col"),
380+
}
381+
sql = _apply_unary_agg_ops(
382+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
383+
)
384+
snapshot.assert_match(sql, "out.sql")
385+
386+
# Window tests
387+
col_name = "int64_col"
388+
bf_df_int = scalar_types_df[[col_name]]
389+
agg_expr = agg_ops.PopVarOp().as_expr(col_name)
390+
window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),))
391+
sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64")
392+
snapshot.assert_match(sql_window, "window_out.sql")
393+
394+
373395
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
374396
col_name = "int64_col"
375397
bf_df = scalar_types_df[[col_name]]
@@ -428,6 +450,40 @@ def test_shift(scalar_types_df: bpd.DataFrame, snapshot):
428450
snapshot.assert_match(noop_sql, "noop.sql")
429451

430452

453+
def test_std(scalar_types_df: bpd.DataFrame, snapshot):
454+
col_names = ["int64_col", "bool_col", "duration_col"]
455+
bf_df = scalar_types_df[col_names]
456+
bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us")
457+
458+
# The `to_timedelta` creates a new mapping for the column id.
459+
col_names.insert(0, "rowindex")
460+
name2id = {
461+
col_name: col_id
462+
for col_name, col_id in zip(col_names, bf_df._block.expr.column_ids)
463+
}
464+
465+
agg_ops_map = {
466+
"int64_col": agg_ops.StdOp().as_expr(name2id["int64_col"]),
467+
"bool_col": agg_ops.StdOp().as_expr(name2id["bool_col"]),
468+
"duration_col": agg_ops.StdOp().as_expr(name2id["duration_col"]),
469+
"int64_col_w_floor": agg_ops.StdOp(should_floor_result=True).as_expr(
470+
name2id["int64_col"]
471+
),
472+
}
473+
sql = _apply_unary_agg_ops(
474+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
475+
)
476+
snapshot.assert_match(sql, "out.sql")
477+
478+
# Window tests
479+
col_name = "int64_col"
480+
bf_df_int = scalar_types_df[[col_name]]
481+
agg_expr = agg_ops.StdOp().as_expr(col_name)
482+
window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),))
483+
sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64")
484+
snapshot.assert_match(sql_window, "window_out.sql")
485+
486+
431487
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
432488
bf_df = scalar_types_df[["int64_col", "bool_col"]]
433489
agg_ops_map = {
@@ -468,3 +524,25 @@ def test_time_series_diff(scalar_types_df: bpd.DataFrame, snapshot):
468524
)
469525
sql = _apply_unary_window_op(bf_df, op, window, "diff_time")
470526
snapshot.assert_match(sql, "out.sql")
527+
528+
529+
def test_var(scalar_types_df: bpd.DataFrame, snapshot):
530+
col_names = ["int64_col", "bool_col"]
531+
bf_df = scalar_types_df[col_names]
532+
533+
agg_ops_map = {
534+
"int64_col": agg_ops.VarOp().as_expr("int64_col"),
535+
"bool_col": agg_ops.VarOp().as_expr("bool_col"),
536+
}
537+
sql = _apply_unary_agg_ops(
538+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
539+
)
540+
snapshot.assert_match(sql, "out.sql")
541+
542+
# Window tests
543+
col_name = "int64_col"
544+
bf_df_int = scalar_types_df[[col_name]]
545+
agg_expr = agg_ops.VarOp().as_expr(col_name)
546+
window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),))
547+
sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64")
548+
snapshot.assert_match(sql_window, "window_out.sql")

0 commit comments

Comments
 (0)