Commit 8439f86

SNOW-150938: Fix df.analytics.time_series_agg() function aggregation bug (#1851)

1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

   Fixes SNOW-150938

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.

3. Please describe how your code solves the related issue.

   `df.analytics.time_series_agg()` was aggregating by the sliding point instead of by the timestamp column. When multiple data points fell into the same sliding interval, this produced duplicated calculations.
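   To illustrate the failure mode, here is a minimal pandas sketch of the flow this function uses (bucket each row to a sliding point, self-join to collect each row's window, then aggregate). pandas stands in for the Snowpark engine, and the column names are illustrative, not the library's internals:

   ```python
   import pandas as pd

   # Two transactions land in the same 1D sliding interval (2024-02-15).
   df = pd.DataFrame(
       {
           "TS": pd.to_datetime(["2024-02-15 00:00:00", "2024-02-15 08:00:00"]),
           "QUANTITY": [15, 7],
       }
   )
   df["SLIDING_POINT"] = df["TS"].dt.floor("1D")  # both rows share one bucket

   # Self-join on the sliding point to collect each row's window contents.
   joined = df.merge(df, on="SLIDING_POINT", suffixes=("", "_B"))

   # Buggy grouping key: each window row is summed once per row sharing the bucket.
   print(joined.groupby("SLIDING_POINT")["QUANTITY_B"].sum())  # 44 -- double-counted

   # Fixed grouping key: one aggregate per original timestamp.
   print(joined.groupby("TS")["QUANTITY_B"].sum())  # 22 and 22
   ```

   Grouping the final aggregation by the timestamp column keeps one window per input row, which is what the one-line change below does.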
sfc-gh-rsureshbabu authored Oct 19, 2024
1 parent 92a7365 commit 8439f86
Showing 3 changed files with 82 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@
#### Bug Fixes

- Fixed a bug where the automatic cleanup of temporary tables could interfere with the results of async query execution.
- Fixed a bug in the `DataFrame.analytics.time_series_agg` function to handle multiple data points in the same sliding interval.

### Snowpark pandas API Updates

4 changes: 2 additions & 2 deletions src/snowflake/snowpark/dataframe_analytics_functions.py
@@ -592,7 +592,7 @@ def time_series_agg(
... )
>>> res.show()
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"PRODUCTKEY" |"SLIDING_POINT" |"SALESAMOUNT" |"ORDERDATE" |"SUM_SALESAMOUNT_1D" |"MAX_SALESAMOUNT_1D" |"SUM_SALESAMOUNT_-1D" |"MAX_SALESAMOUNT_-1D" |
|"PRODUCTKEY" |"ORDERDATE" |"SALESAMOUNT" |"SLIDING_POINT" |"SUM_SALESAMOUNT_1D" |"MAX_SALESAMOUNT_1D" |"SUM_SALESAMOUNT_-1D" |"MAX_SALESAMOUNT_-1D" |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|101 |2023-01-01 00:00:00 |200 |2023-01-01 00:00:00 |300 |200 |200 |200 |
|101 |2023-01-02 00:00:00 |100 |2023-01-02 00:00:00 |400 |300 |300 |200 |
@@ -664,7 +664,7 @@ def time_series_agg(
).filter(col(f"{sliding_point_col}B") <= window_end)

# Perform final aggregations.
-group_by_cols = group_by + [sliding_point_col]
+group_by_cols = group_by + [time_col]
result_df = self._perform_window_aggregations(
result_df,
self_joined_df,
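The fix changes the final grouping key from `group_by + [sliding_point_col]` to `group_by + [time_col]`, so each input row keeps its own window aggregate even when several rows share a sliding point. A hedged usage sketch (the DataFrame and its columns are placeholders; the call mirrors the new regression test below):

```python
# Hypothetical input: df has columns TS, PRODUCT_ID, QUANTITY (see test below).
res = df.analytics.time_series_agg(
    time_col="TS",
    group_by=["PRODUCT_ID"],
    aggs={"QUANTITY": ["SUM"]},
    windows=["-1D"],
    sliding_interval="1D",
)
# After the fix, res keeps one row per input row (assuming timestamps are
# unique within each group), each with its own QUANTITY_SUM_-1D value.
```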
103 changes: 79 additions & 24 deletions tests/integ/test_df_analytics.py
@@ -443,9 +443,9 @@ def custom_formatter(input_col, agg, window):
# Define the expected data
expected_data = {
"PRODUCTKEY": [101, 101, 101, 102],
"SLIDING_POINT": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"SALESAMOUNT": [200, 100, 300, 250],
"ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"SALESAMOUNT": [200, 100, 300, 250],
"SLIDING_POINT": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"SUM_SALESAMOUNT_1D": [300, 400, 300, 250],
"MAX_SALESAMOUNT_1D": [200, 300, 300, 250],
"SUM_SALESAMOUNT_-1D": [200, 300, 400, 250],
@@ -466,6 +466,61 @@ def custom_formatter(input_col, agg, window):
)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_time_series_aggregation_grouping_bug_fix(session):
data = [
["2024-02-01 00:00:00", "product_A", "transaction_1", 10],
["2024-02-15 00:00:00", "product_A", "transaction_2", 15],
["2024-02-15 08:00:00", "product_A", "transaction_3", 7],
["2024-02-17 00:00:00", "product_A", "transaction_4", 3],
]
df = session.create_dataframe(data).to_df(
"TS", "PRODUCT_ID", "TRANSACTION_ID", "QUANTITY"
)

res = df.analytics.time_series_agg(
time_col="TS",
group_by=["PRODUCT_ID"],
aggs={"QUANTITY": ["SUM"]},
windows=["-1D", "-7D"],
sliding_interval="1D",
)

expected_data = {
"PRODUCT_ID": ["product_A", "product_A", "product_A", "product_A"],
"TS": [
"2024-02-01 00:00:00",
"2024-02-15 00:00:00",
"2024-02-15 08:00:00",
"2024-02-17 00:00:00",
],
"TRANSACTION_ID": [
"transaction_1",
"transaction_2",
"transaction_3",
"transaction_4",
],
"QUANTITY": [10, 15, 7, 3],
"SLIDING_POINT": [
"2024-02-01 00:00:00",
"2024-02-15 00:00:00",
"2024-02-15 00:00:00",
"2024-02-17 00:00:00",
],
"QUANTITY_SUM_-1D": [10, 22, 22, 3],
"QUANTITY_SUM_-7D": [10, 22, 22, 25],
}

expected_df = pd.DataFrame(expected_data)

expected_df["SLIDING_POINT"] = pd.to_datetime(expected_df["SLIDING_POINT"])

# Compare the result to the expected DataFrame
assert_frame_equal(
res.order_by("TS").to_pandas(), expected_df, check_dtype=False, atol=1e-1
)
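
As a sanity check, the expected `QUANTITY_SUM_-1D` values in this test can be reproduced with plain pandas. This is a rough stand-in: the bucket-and-window arithmetic below is inferred from the expected data, not taken from the library:

```python
import pandas as pd

ts = pd.to_datetime(
    [
        "2024-02-01 00:00:00",
        "2024-02-15 00:00:00",
        "2024-02-15 08:00:00",
        "2024-02-17 00:00:00",
    ]
)
qty = pd.Series([10, 15, 7, 3])
bucket = ts.floor("1D")  # 1D sliding interval

# For each row, sum the quantities whose bucket falls within one day back.
sums = [
    qty[(bucket >= b - pd.Timedelta(days=1)) & (bucket <= b)].sum() for b in bucket
]
print(sums)  # [10, 22, 22, 3] -- matches QUANTITY_SUM_-1D above
```

Note that both 2024-02-15 rows get 22 while remaining separate output rows, which is exactly the behavior the sliding-point grouping used to break.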


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
@@ -502,17 +557,6 @@

expected_data = {
"PRODUCTKEY": [101, 101, 101, 101, 102, 102, 102, 102],
"SLIDING_POINT": [
"2023-01-01",
"2023-02-01",
"2023-03-01",
"2023-04-01",
"2023-02-01",
"2023-03-01",
"2023-04-01",
"2023-05-01",
],
"SALESAMOUNT": [100, 200, 300, 400, 150, 250, 350, 450],
"ORDERDATE": [
"2023-01-15",
"2023-02-15",
@@ -523,6 +567,17 @@
"2023-03-20",
"2023-04-20",
],
"SALESAMOUNT": [100, 200, 300, 400, 150, 250, 350, 450],
"SLIDING_POINT": [
"2023-01-01",
"2023-02-01",
"2023-03-01",
"2023-04-01",
"2023-02-01",
"2023-03-01",
"2023-04-01",
"2023-05-01",
],
"SUM_SALESAMOUNT_-2mm": [100, 300, 600, 900, 150, 400, 750, 1050],
"MAX_SALESAMOUNT_-2mm": [100, 200, 300, 400, 150, 250, 350, 450],
}
@@ -573,17 +628,6 @@ def custom_formatter(input_col, agg, window):
# Calculated expected data for 2Y window with 1Y sliding interval
expected_data = {
"PRODUCTKEY": [101, 101, 101, 101, 102, 102, 102, 102],
"SLIDING_POINT": [
"2021-01-01",
"2022-01-01",
"2023-01-01",
"2024-01-01",
"2021-01-01",
"2022-01-01",
"2023-01-01",
"2024-01-01",
],
"SALESAMOUNT": [100, 200, 300, 400, 150, 250, 350, 450],
"ORDERDATE": [
"2021-01-15",
"2022-01-15",
@@ -594,6 +638,17 @@
"2023-01-20",
"2024-01-20",
],
"SALESAMOUNT": [100, 200, 300, 400, 150, 250, 350, 450],
"SLIDING_POINT": [
"2021-01-01",
"2022-01-01",
"2023-01-01",
"2024-01-01",
"2021-01-01",
"2022-01-01",
"2023-01-01",
"2024-01-01",
],
"SUM_SALESAMOUNT_-1Y": [100, 300, 500, 700, 150, 400, 600, 800],
"MAX_SALESAMOUNT_-1Y": [100, 200, 300, 400, 150, 250, 350, 450],
}
