Skip to content

Commit 6c81445

Browse files
committed
refactor: add struct_op for the sqlglot compiler
1 parent 5ec3cc0 commit 6c81445

File tree

3 files changed

+64
-0
lines changed

3 files changed

+64
-0
lines changed

bigframes/core/compile/sqlglot/expressions/struct_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2525
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2626

27+
# Short alias for the shared scalar op compiler's n-ary registration hook;
# used below as a decorator to register the StructOp compiler.
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
2728
# Short alias for the shared scalar op compiler's unary registration hook.
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2829

2930

@@ -40,3 +41,13 @@ def _(expr: TypedExpr, op: ops.StructFieldOp) -> sge.Expression:
4041
this=sge.to_identifier(name, quoted=True),
4142
catalog=expr.expr,
4243
)
44+
45+
46+
@register_nary_op(ops.StructOp, pass_op=True)
def _(*exprs: TypedExpr, op: ops.StructOp) -> sge.Struct:
    """Compile a ``StructOp`` into a sqlglot ``Struct`` expression.

    Operands are paired positionally with ``op.column_names``: the n-th
    operand becomes the struct field named by the n-th column name.

    Args:
        *exprs: The operands whose ``.expr`` values become the field values.
        op: The struct op; its ``column_names`` supply the field names.

    Returns:
        A ``sge.Struct`` holding one ``sge.PropertyEQ`` per (name, operand)
        pair.
    """
    fields = []
    # zip() stops at the shorter of the two sequences, so only complete
    # (name, operand) pairs produce a field.
    for field_name, operand in zip(op.column_names, exprs):
        fields.append(
            sge.PropertyEQ(
                this=sge.to_identifier(field_name),
                expression=operand.expr,
            )
        )
    return sge.Struct(expressions=fields)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`float64_col` AS `bfcol_2`,
6+
`string_col` AS `bfcol_3`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
8+
), `bfcte_1` AS (
9+
SELECT
10+
*,
11+
STRUCT(
12+
`bfcol_0` AS bool_col,
13+
`bfcol_1` AS int64_col,
14+
`bfcol_2` AS float64_col,
15+
`bfcol_3` AS string_col
16+
) AS `bfcol_4`
17+
FROM `bfcte_0`
18+
)
19+
SELECT
20+
`bfcol_4` AS `result_col`
21+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_struct_ops.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,39 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import typing
16+
1517
import pytest
1618

1719
from bigframes import operations as ops
20+
from bigframes.core import expression as ex
1821
import bigframes.pandas as bpd
1922
from bigframes.testing import utils
2023

2124
pytest.importorskip("pytest_snapshot")
2225

2326

27+
def _apply_nary_op(
    obj: bpd.DataFrame,
    op: ops.NaryOp,
    *args: typing.Union[str, ex.Expression],
) -> str:
    """Apply an n-ary op to ``obj`` and return the SQL for the result.

    The op expression is built from ``args``, computed on the DataFrame's
    underlying block expression, and projected down to a single column
    named ``result_col`` before the SQL is generated.

    Args:
        obj: The DataFrame whose block expression the op is applied to.
        op: The n-ary op to apply.
        *args: The op's inputs, as column names or expressions.

    Returns:
        The SQL text produced by the session executor with caching disabled.
    """
    source = obj._block.expr
    computed, new_ids = source.compute_values([op.as_expr(*args)])

    # Rename columns for deterministic golden SQL results.
    assert len(new_ids) == 1
    renamed = computed.rename_columns({new_ids[0]: "result_col"})
    projected = renamed.select_columns(["result_col"])

    return projected.session._executor.to_sql(projected, enable_cache=False)
46+
47+
2448
def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
2549
col_name = "people"
2650
bf_df = nested_structs_types_df[[col_name]]
@@ -34,3 +58,11 @@ def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
3458
sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
3559

3660
snapshot.assert_match(sql, "out.sql")
61+
62+
63+
def test_struct_op(scalar_types_df: bpd.DataFrame, snapshot):
    """Check the SQL generated for a StructOp over four scalar columns
    against the stored ``out.sql`` snapshot."""
    selected = ["bool_col", "int64_col", "float64_col", "string_col"]
    bf_df = scalar_types_df[selected]

    # The same column labels serve as both the struct field names and the
    # op's inputs.
    field_names = bf_df.columns.tolist()
    struct_op = ops.StructOp(column_names=tuple(field_names))

    sql = _apply_nary_op(bf_df, struct_op, *field_names)

    snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)