Skip to content

Commit 4c98c95

Browse files
authored
refactor: add struct_op and sql_scalar_op for the sqlglot compiler (#2197)
1 parent 3e6299f commit 4c98c95

File tree

7 files changed

+112
-2
lines changed

7 files changed

+112
-2
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import sqlglot as sg
1718
import sqlglot.expressions as sge
1819

1920
from bigframes import dtypes
@@ -80,6 +81,16 @@ def _(expr: TypedExpr) -> sge.Expression:
8081
return sge.BitwiseNot(this=sge.paren(expr.expr))
8182

8283

84+
@register_nary_op(ops.SqlScalarOp, pass_op=True)
85+
def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
86+
return sg.parse_one(
87+
op.sql_template.format(
88+
*[operand.expr.sql(dialect="bigquery") for operand in operands]
89+
),
90+
dialect="bigquery",
91+
)
92+
93+
8394
@register_unary_op(ops.isnull_op)
8495
def _(expr: TypedExpr) -> sge.Expression:
8596
return sge.Is(this=expr.expr, expression=sge.Null())

bigframes/core/compile/sqlglot/expressions/struct_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2525
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2626

27+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
2728
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2829

2930

@@ -40,3 +41,13 @@ def _(expr: TypedExpr, op: ops.StructFieldOp) -> sge.Expression:
4041
this=sge.to_identifier(name, quoted=True),
4142
catalog=expr.expr,
4243
)
44+
45+
46+
@register_nary_op(ops.StructOp, pass_op=True)
47+
def _(*exprs: TypedExpr, op: ops.StructOp) -> sge.Struct:
48+
return sge.Struct(
49+
expressions=[
50+
sge.PropertyEQ(this=sge.to_identifier(col), expression=expr.expr)
51+
for col, expr in zip(op.column_names, exprs)
52+
]
53+
)

bigframes/testing/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,13 +475,23 @@ def _apply_binary_op(
475475
) -> str:
476476
"""Applies a binary op to the given DataFrame and return the SQL representing
477477
the resulting DataFrame."""
478+
return _apply_nary_op(obj, op, l_arg, r_arg)
479+
480+
481+
def _apply_nary_op(
482+
obj: bpd.DataFrame,
483+
op: Union[ops.BinaryOp, ops.NaryOp],
484+
*args: Union[str, ex.Expression],
485+
) -> str:
486+
"""Applies a nary op to the given DataFrame and return the SQL representing
487+
the resulting DataFrame."""
478488
array_value = obj._block.expr
479-
op_expr = op.as_expr(l_arg, r_arg)
489+
op_expr = op.as_expr(*args)
480490
result, col_ids = array_value.compute_values([op_expr])
481491

482492
# Rename columns for deterministic golden SQL results.
483493
assert len(col_ids) == 1
484-
result = result.rename_columns({col_ids[0]: l_arg}).select_columns([l_arg])
494+
result = result.rename_columns({col_ids[0]: args[0]}).select_columns([args[0]])
485495

486496
sql = result.session._executor.to_sql(result, enable_cache=False)
487497
return sql
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`bytes_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CAST(`bfcol_0` AS INT64) + BYTE_LENGTH(`bfcol_1`) AS `bfcol_2`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_2` AS `bool_col`
14+
FROM `bfcte_1`
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`float64_col` AS `bfcol_2`,
6+
`string_col` AS `bfcol_3`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
8+
), `bfcte_1` AS (
9+
SELECT
10+
*,
11+
STRUCT(
12+
`bfcol_0` AS bool_col,
13+
`bfcol_1` AS int64_col,
14+
`bfcol_2` AS float64_col,
15+
`bfcol_3` AS string_col
16+
) AS `bfcol_4`
17+
FROM `bfcte_0`
18+
)
19+
SELECT
20+
`bfcol_4` AS `result_col`
21+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,17 @@ def test_notnull(scalar_types_df: bpd.DataFrame, snapshot):
261261
snapshot.assert_match(sql, "out.sql")
262262

263263

264+
def test_sql_scalar_op(scalar_types_df: bpd.DataFrame, snapshot):
265+
bf_df = scalar_types_df[["bool_col", "bytes_col"]]
266+
sql = utils._apply_nary_op(
267+
bf_df,
268+
ops.SqlScalarOp(dtypes.INT_DTYPE, "CAST({0} AS INT64) + BYTE_LENGTH({1})"),
269+
"bool_col",
270+
"bytes_col",
271+
)
272+
snapshot.assert_match(sql, "out.sql")
273+
274+
264275
def test_map(scalar_types_df: bpd.DataFrame, snapshot):
265276
col_name = "string_col"
266277
bf_df = scalar_types_df[[col_name]]

tests/unit/core/compile/sqlglot/expressions/test_struct_ops.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,39 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing
16+
1517
import pytest
1618

1719
from bigframes import operations as ops
20+
from bigframes.core import expression as ex
1821
import bigframes.pandas as bpd
1922
from bigframes.testing import utils
2023

2124
pytest.importorskip("pytest_snapshot")
2225

2326

27+
def _apply_nary_op(
28+
obj: bpd.DataFrame,
29+
op: ops.NaryOp,
30+
*args: typing.Union[str, ex.Expression],
31+
) -> str:
32+
"""Applies a nary op to the given DataFrame and return the SQL representing
33+
the resulting DataFrame."""
34+
array_value = obj._block.expr
35+
op_expr = op.as_expr(*args)
36+
result, col_ids = array_value.compute_values([op_expr])
37+
38+
# Rename columns for deterministic golden SQL results.
39+
assert len(col_ids) == 1
40+
result = result.rename_columns({col_ids[0]: "result_col"}).select_columns(
41+
["result_col"]
42+
)
43+
44+
sql = result.session._executor.to_sql(result, enable_cache=False)
45+
return sql
46+
47+
2448
def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
2549
col_name = "people"
2650
bf_df = nested_structs_types_df[[col_name]]
@@ -34,3 +58,11 @@ def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
3458
sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
3559

3660
snapshot.assert_match(sql, "out.sql")
61+
62+
63+
def test_struct_op(scalar_types_df: bpd.DataFrame, snapshot):
64+
bf_df = scalar_types_df[["bool_col", "int64_col", "float64_col", "string_col"]]
65+
op = ops.StructOp(column_names=tuple(bf_df.columns.tolist()))
66+
sql = _apply_nary_op(bf_df, op, *bf_df.columns.tolist())
67+
68+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)