Skip to content

Commit 636578e

Browse files
committed
refactor: enable engine tests for datetime ops
1 parent 5ec3cc0 commit 636578e

File tree

7 files changed

+107
-26
lines changed

7 files changed

+107
-26
lines changed

bigframes/core/compile/sqlglot/expressions/date_ops.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@ def _(expr: TypedExpr) -> sge.Expression:
3535

3636
@register_unary_op(ops.dayofweek_op)
3737
def _(expr: TypedExpr) -> sge.Expression:
38-
# Adjust the 1-based day-of-week index (from SQL) to a 0-based index.
39-
return sge.Extract(
40-
this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr
41-
) - sge.convert(1)
38+
return dayofweek_op_impl(expr)
4239

4340

4441
@register_unary_op(ops.dayofyear_op)
@@ -48,7 +45,8 @@ def _(expr: TypedExpr) -> sge.Expression:
4845

4946
@register_unary_op(ops.iso_day_op)
5047
def _(expr: TypedExpr) -> sge.Expression:
51-
return sge.Extract(this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr)
48+
# Plus 1 because iso day of week uses 1-based indexing
49+
return dayofweek_op_impl(expr) + sge.convert(1)
5250

5351

5452
@register_unary_op(ops.iso_week_op)
@@ -59,3 +57,15 @@ def _(expr: TypedExpr) -> sge.Expression:
5957
@register_unary_op(ops.iso_year_op)
6058
def _(expr: TypedExpr) -> sge.Expression:
6159
return sge.Extract(this=sge.Identifier(this="ISOYEAR"), expression=expr.expr)
60+
61+
62+
# Helpers
63+
def dayofweek_op_impl(expr: TypedExpr) -> sge.Expression:
64+
# Adjust the 1-based day-of-week index (from SQL) to a 0-based index.
65+
extract_expr = sge.Extract(
66+
this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr
67+
)
68+
return sge.Cast(
69+
this=sge.Mod(this=extract_expr + sge.convert(5), expression=sge.convert(7)),
70+
to="INT64",
71+
)

bigframes/core/compile/sqlglot/expressions/datetime_ops.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,28 @@
2525

2626
@register_unary_op(ops.FloorDtOp, pass_op=True)
2727
def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
28-
# TODO: Remove this method when it is covered by ops.FloorOp
29-
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=op.freq))
28+
pandas_to_bq_freq_map = {
29+
"Y": "YEAR",
30+
"Q": "QUARTER",
31+
"M": "MONTH",
32+
"W": "WEEK(MONDAY)",
33+
"D": "DAY",
34+
"h": "HOUR",
35+
"min": "MINUTE",
36+
"s": "SECOND",
37+
"ms": "MILLISECOND",
38+
"us": "MICROSECOND",
39+
"ns": "NANOSECOND",
40+
}
41+
if op.freq not in pandas_to_bq_freq_map.keys():
42+
raise NotImplementedError(
43+
f"Unsupported freq parameter: {op.freq}"
44+
+ " Supported freq parameters are: "
45+
+ ",".join(pandas_to_bq_freq_map.keys())
46+
)
47+
48+
bq_freq = pandas_to_bq_freq_map[op.freq]
49+
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))
3050

3151

3252
@register_unary_op(ops.hour_op)

tests/system/small/engines/test_temporal_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2626

2727

28-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
28+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
2929
def test_engines_dt_floor(scalars_array_value: array_value.ArrayValue, engine):
3030
arr, _ = scalars_array_value.compute_values(
3131
[
@@ -46,7 +46,7 @@ def test_engines_dt_floor(scalars_array_value: array_value.ArrayValue, engine):
4646
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
4747

4848

49-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
49+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5050
def test_engines_date_accessors(scalars_array_value: array_value.ArrayValue, engine):
5151
datelike_cols = ["datetime_col", "timestamp_col", "date_col"]
5252
accessors = [
Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`timestamp_col` AS `bfcol_0`
3+
`date_col` AS `bfcol_0`,
4+
`datetime_col` AS `bfcol_1`,
5+
`timestamp_col` AS `bfcol_2`
46
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
57
), `bfcte_1` AS (
68
SELECT
79
*,
8-
EXTRACT(DAYOFWEEK FROM `bfcol_0`) - 1 AS `bfcol_1`
10+
CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_1`) + 5, 7) AS INT64) AS `bfcol_6`,
11+
CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_2`) + 5, 7) AS INT64) AS `bfcol_7`,
12+
CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_0`) + 5, 7) AS INT64) AS `bfcol_8`
913
FROM `bfcte_0`
1014
)
1115
SELECT
12-
`bfcol_1` AS `timestamp_col`
16+
`bfcol_6` AS `datetime_col`,
17+
`bfcol_7` AS `timestamp_col`,
18+
`bfcol_8` AS `date_col`
1319
FROM `bfcte_1`
Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`timestamp_col` AS `bfcol_0`
3+
`datetime_col` AS `bfcol_0`,
4+
`timestamp_col` AS `bfcol_1`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
78
*,
8-
TIMESTAMP_TRUNC(`bfcol_0`, D) AS `bfcol_1`
9+
TIMESTAMP_TRUNC(`bfcol_1`, MICROSECOND) AS `bfcol_2`,
10+
TIMESTAMP_TRUNC(`bfcol_1`, MILLISECOND) AS `bfcol_3`,
11+
TIMESTAMP_TRUNC(`bfcol_1`, SECOND) AS `bfcol_4`,
12+
TIMESTAMP_TRUNC(`bfcol_1`, MINUTE) AS `bfcol_5`,
13+
TIMESTAMP_TRUNC(`bfcol_1`, HOUR) AS `bfcol_6`,
14+
TIMESTAMP_TRUNC(`bfcol_1`, DAY) AS `bfcol_7`,
15+
TIMESTAMP_TRUNC(`bfcol_1`, WEEK(MONDAY)) AS `bfcol_8`,
16+
TIMESTAMP_TRUNC(`bfcol_1`, MONTH) AS `bfcol_9`,
17+
TIMESTAMP_TRUNC(`bfcol_1`, QUARTER) AS `bfcol_10`,
18+
TIMESTAMP_TRUNC(`bfcol_1`, YEAR) AS `bfcol_11`,
19+
TIMESTAMP_TRUNC(`bfcol_0`, MICROSECOND) AS `bfcol_12`,
20+
TIMESTAMP_TRUNC(`bfcol_0`, MICROSECOND) AS `bfcol_13`
921
FROM `bfcte_0`
1022
)
1123
SELECT
12-
`bfcol_1` AS `timestamp_col`
24+
`bfcol_2` AS `timestamp_col_us`,
25+
`bfcol_3` AS `timestamp_col_ms`,
26+
`bfcol_4` AS `timestamp_col_s`,
27+
`bfcol_5` AS `timestamp_col_min`,
28+
`bfcol_6` AS `timestamp_col_h`,
29+
`bfcol_7` AS `timestamp_col_D`,
30+
`bfcol_8` AS `timestamp_col_W`,
31+
`bfcol_9` AS `timestamp_col_M`,
32+
`bfcol_10` AS `timestamp_col_Q`,
33+
`bfcol_11` AS `timestamp_col_Y`,
34+
`bfcol_12` AS `datetime_col_q`,
35+
`bfcol_13` AS `datetime_col_us`
1336
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_iso_day/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
EXTRACT(DAYOFWEEK FROM `bfcol_0`) AS `bfcol_1`
8+
CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_0`) + 5, 7) AS INT64) + 1 AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,11 @@ def test_day(scalar_types_df: bpd.DataFrame, snapshot):
3939

4040

4141
def test_dayofweek(scalar_types_df: bpd.DataFrame, snapshot):
42-
col_name = "timestamp_col"
43-
bf_df = scalar_types_df[[col_name]]
44-
sql = utils._apply_unary_ops(
45-
bf_df, [ops.dayofweek_op.as_expr(col_name)], [col_name]
46-
)
42+
col_names = ["datetime_col", "timestamp_col", "date_col"]
43+
bf_df = scalar_types_df[col_names]
44+
ops_map = {col_name: ops.dayofweek_op.as_expr(col_name) for col_name in col_names}
4745

46+
sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
4847
snapshot.assert_match(sql, "out.sql")
4948

5049

@@ -59,13 +58,36 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot):
5958

6059

6160
def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot):
61+
col_names = ["datetime_col", "timestamp_col", "date_col"]
62+
bf_df = scalar_types_df[col_names]
63+
ops_map = {
64+
"timestamp_col_us": ops.FloorDtOp("us").as_expr("timestamp_col"),
65+
"timestamp_col_ms": ops.FloorDtOp("ms").as_expr("timestamp_col"),
66+
"timestamp_col_s": ops.FloorDtOp("s").as_expr("timestamp_col"),
67+
"timestamp_col_min": ops.FloorDtOp("min").as_expr("timestamp_col"),
68+
"timestamp_col_h": ops.FloorDtOp("h").as_expr("timestamp_col"),
69+
"timestamp_col_D": ops.FloorDtOp("D").as_expr("timestamp_col"),
70+
"timestamp_col_W": ops.FloorDtOp("W").as_expr("timestamp_col"),
71+
"timestamp_col_M": ops.FloorDtOp("M").as_expr("timestamp_col"),
72+
"timestamp_col_Q": ops.FloorDtOp("Q").as_expr("timestamp_col"),
73+
"timestamp_col_Y": ops.FloorDtOp("Y").as_expr("timestamp_col"),
74+
"datetime_col_q": ops.FloorDtOp("us").as_expr("datetime_col"),
75+
"datetime_col_us": ops.FloorDtOp("us").as_expr("datetime_col"),
76+
}
77+
78+
sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
79+
snapshot.assert_match(sql, "out.sql")
80+
81+
82+
def test_floor_dt_op_invalid_freq(scalar_types_df: bpd.DataFrame):
6283
col_name = "timestamp_col"
6384
bf_df = scalar_types_df[[col_name]]
64-
sql = utils._apply_unary_ops(
65-
bf_df, [ops.FloorDtOp("D").as_expr(col_name)], [col_name]
66-
)
67-
68-
snapshot.assert_match(sql, "out.sql")
85+
with pytest.raises(
86+
NotImplementedError, match="Unsupported freq parameter: invalid"
87+
):
88+
utils._apply_unary_ops(
89+
bf_df, [ops.FloorDtOp(freq="invalid").as_expr(col_name)], [col_name]
90+
)
6991

7092

7193
def test_hour(scalar_types_df: bpd.DataFrame, snapshot):

0 commit comments

Comments
 (0)