diff --git a/bigframes/core/compile/sqlglot/expressions/date_ops.py b/bigframes/core/compile/sqlglot/expressions/date_ops.py index f5922ecc8d..be772d978d 100644 --- a/bigframes/core/compile/sqlglot/expressions/date_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/date_ops.py @@ -35,10 +35,7 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.dayofweek_op) def _(expr: TypedExpr) -> sge.Expression: - # Adjust the 1-based day-of-week index (from SQL) to a 0-based index. - return sge.Extract( - this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr - ) - sge.convert(1) + return dayofweek_op_impl(expr) @register_unary_op(ops.dayofyear_op) @@ -48,7 +45,8 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.iso_day_op) def _(expr: TypedExpr) -> sge.Expression: - return sge.Extract(this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr) + # Plus 1 because iso day of week uses 1-based indexing + return dayofweek_op_impl(expr) + sge.convert(1) @register_unary_op(ops.iso_week_op) @@ -59,3 +57,16 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.iso_year_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="ISOYEAR"), expression=expr.expr) + + +# Helpers +def dayofweek_op_impl(expr: TypedExpr) -> sge.Expression: + # BigQuery SQL Extract(DAYOFWEEK) returns 1 for Sunday through 7 for Saturday. + # We want 0 for Monday through 6 for Sunday to be compatible with Pandas. + extract_expr = sge.Extract( + this=sge.Identifier(this="DAYOFWEEK"), expression=expr.expr + ) + return sge.Cast( + this=sge.Mod(this=extract_expr + sge.convert(5), expression=sge.convert(7)), + to="INT64", + ) diff --git a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py index 77f4233e1c..949b122a1d 100644 --- a/bigframes/core/compile/sqlglot/expressions/datetime_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/datetime_ops.py @@ -25,8 +25,28 @@ @register_unary_op(ops.FloorDtOp, pass_op=True) def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression: - # TODO: Remove this method when it is covered by ops.FloorOp - return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=op.freq)) + pandas_to_bq_freq_map = { + "Y": "YEAR", + "Q": "QUARTER", + "M": "MONTH", + "W": "WEEK(MONDAY)", + "D": "DAY", + "h": "HOUR", + "min": "MINUTE", + "s": "SECOND", + "ms": "MILLISECOND", + "us": "MICROSECOND", + "ns": "NANOSECOND", + } + if op.freq not in pandas_to_bq_freq_map.keys(): + raise NotImplementedError( + f"Unsupported freq paramater: {op.freq}" + + " Supported freq parameters are: " + + ",".join(pandas_to_bq_freq_map.keys()) + ) + + bq_freq = pandas_to_bq_freq_map[op.freq] + return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq)) @register_unary_op(ops.hour_op) diff --git a/tests/system/small/engines/test_temporal_ops.py b/tests/system/small/engines/test_temporal_ops.py index 5a39587886..66edfeddcc 100644 --- a/tests/system/small/engines/test_temporal_ops.py +++ b/tests/system/small/engines/test_temporal_ops.py @@ -25,7 +25,7 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_dt_floor(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ @@ -46,7 +46,7 @@ def test_engines_dt_floor(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_date_accessors(scalars_array_value: array_value.ArrayValue, engine): datelike_cols = ["datetime_col", "timestamp_col", "date_col"] accessors = [ diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_dayofweek/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_dayofweek/out.sql index e6c17587d0..55d3832c1f 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_dayofweek/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_dayofweek/out.sql @@ -1,13 +1,19 @@ WITH `bfcte_0` AS ( SELECT - `timestamp_col` AS `bfcol_0` + `date_col` AS `bfcol_0`, + `datetime_col` AS `bfcol_1`, + `timestamp_col` AS `bfcol_2` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - EXTRACT(DAYOFWEEK FROM `bfcol_0`) - 1 AS `bfcol_1` + CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_1`) + 5, 7) AS INT64) AS `bfcol_6`, + CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_2`) + 5, 7) AS INT64) AS `bfcol_7`, + CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_0`) + 5, 7) AS INT64) AS `bfcol_8` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `timestamp_col` + `bfcol_6` AS `datetime_col`, + `bfcol_7` AS `timestamp_col`, + `bfcol_8` AS `date_col` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_floor_dt/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_floor_dt/out.sql index ad4fdb23a1..a8877f8cfa 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_floor_dt/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_floor_dt/out.sql @@ -1,13 +1,36 @@ WITH `bfcte_0` AS ( SELECT - `timestamp_col` AS `bfcol_0` + `datetime_col` AS `bfcol_0`, + `timestamp_col` AS `bfcol_1` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - TIMESTAMP_TRUNC(`bfcol_0`, D) AS `bfcol_1` + TIMESTAMP_TRUNC(`bfcol_1`, MICROSECOND) AS `bfcol_2`, + TIMESTAMP_TRUNC(`bfcol_1`, MILLISECOND) AS `bfcol_3`, + TIMESTAMP_TRUNC(`bfcol_1`, SECOND) AS `bfcol_4`, + TIMESTAMP_TRUNC(`bfcol_1`, MINUTE) AS `bfcol_5`, + TIMESTAMP_TRUNC(`bfcol_1`, HOUR) AS `bfcol_6`, + TIMESTAMP_TRUNC(`bfcol_1`, DAY) AS `bfcol_7`, + TIMESTAMP_TRUNC(`bfcol_1`, WEEK(MONDAY)) AS `bfcol_8`, + TIMESTAMP_TRUNC(`bfcol_1`, MONTH) AS `bfcol_9`, + TIMESTAMP_TRUNC(`bfcol_1`, QUARTER) AS `bfcol_10`, + TIMESTAMP_TRUNC(`bfcol_1`, YEAR) AS `bfcol_11`, + TIMESTAMP_TRUNC(`bfcol_0`, MICROSECOND) AS `bfcol_12`, + TIMESTAMP_TRUNC(`bfcol_0`, MICROSECOND) AS `bfcol_13` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `timestamp_col` + `bfcol_2` AS `timestamp_col_us`, + `bfcol_3` AS `timestamp_col_ms`, + `bfcol_4` AS `timestamp_col_s`, + `bfcol_5` AS `timestamp_col_min`, + `bfcol_6` AS `timestamp_col_h`, + `bfcol_7` AS `timestamp_col_D`, + `bfcol_8` AS `timestamp_col_W`, + `bfcol_9` AS `timestamp_col_M`, + `bfcol_10` AS `timestamp_col_Q`, + `bfcol_11` AS `timestamp_col_Y`, + `bfcol_12` AS `datetime_col_q`, + `bfcol_13` AS `datetime_col_us` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_iso_day/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_iso_day/out.sql index d389172fda..f7203fc930 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_iso_day/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_iso_day/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - EXTRACT(DAYOFWEEK FROM `bfcol_0`) AS `bfcol_1` + CAST(MOD(EXTRACT(DAYOFWEEK FROM `bfcol_0`) + 5, 7) AS INT64) + 1 AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py index 91926e7bdd..3261113806 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py @@ -39,12 +39,11 @@ def test_day(scalar_types_df: bpd.DataFrame, snapshot): def test_dayofweek(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "timestamp_col" - bf_df = scalar_types_df[[col_name]] - sql = utils._apply_unary_ops( - bf_df, [ops.dayofweek_op.as_expr(col_name)], [col_name] - ) + col_names = ["datetime_col", "timestamp_col", "date_col"] + bf_df = scalar_types_df[col_names] + ops_map = {col_name: ops.dayofweek_op.as_expr(col_name) for col_name in col_names} + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) snapshot.assert_match(sql, "out.sql") @@ -59,13 +58,38 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot): def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["datetime_col", "timestamp_col", "date_col"] + bf_df = scalar_types_df[col_names] + ops_map = { + "timestamp_col_us": ops.FloorDtOp("us").as_expr("timestamp_col"), + "timestamp_col_ms": ops.FloorDtOp("ms").as_expr("timestamp_col"), + "timestamp_col_s": ops.FloorDtOp("s").as_expr("timestamp_col"), + "timestamp_col_min": ops.FloorDtOp("min").as_expr("timestamp_col"), + "timestamp_col_h": ops.FloorDtOp("h").as_expr("timestamp_col"), + "timestamp_col_D": ops.FloorDtOp("D").as_expr("timestamp_col"), + "timestamp_col_W": ops.FloorDtOp("W").as_expr("timestamp_col"), + "timestamp_col_M": ops.FloorDtOp("M").as_expr("timestamp_col"), + "timestamp_col_Q": ops.FloorDtOp("Q").as_expr("timestamp_col"), + "timestamp_col_Y": ops.FloorDtOp("Y").as_expr("timestamp_col"), + "datetime_col_q": ops.FloorDtOp("us").as_expr("datetime_col"), + "datetime_col_us": ops.FloorDtOp("us").as_expr("datetime_col"), + } + + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + snapshot.assert_match(sql, "out.sql") + + +def test_floor_dt_op_invalid_freq(scalar_types_df: bpd.DataFrame): col_name = "timestamp_col" bf_df = scalar_types_df[[col_name]] - sql = utils._apply_unary_ops( - bf_df, [ops.FloorDtOp("D").as_expr(col_name)], [col_name] - ) - - snapshot.assert_match(sql, "out.sql") + with pytest.raises( + NotImplementedError, match="Unsupported freq paramater: invalid" + ): + utils._apply_unary_ops( + bf_df, + [ops.FloorDtOp(freq="invalid").as_expr(col_name)], # type:ignore + [col_name], + ) def test_hour(scalar_types_df: bpd.DataFrame, snapshot):