Skip to content

Commit

Permalink
[Nursery] Cop-DeepAR: Fix evaluation at all levels (#2944)
Browse files Browse the repository at this point in the history
  • Loading branch information
rshyamsundar authored Aug 3, 2023
1 parent 851f3be commit 7ab786b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from itertools import tee
from typing import List

import numpy as np
Expand Down Expand Up @@ -87,27 +88,48 @@ def evaluate_forecasts_at_all_levels(
evaluator,
metrics: List[str] = ["mean_wQuantileLoss"],
):
# Order of forecasts obtained e.g., in case of the hierarchy: [6M, 2M, 1M]
# (ts_1_6M, ts_1_2M, ts_1_1M, ts_2_6M, ts_2_2M, ts_2_1M, ...)
forecast_at_all_levels_unpacked_it = unpack_forecasts(
forecast_at_all_levels_it=forecast_at_all_levels_it,
temporal_hierarchy=temporal_hierarchy,
target_temporal_hierarchy=temporal_hierarchy,
)

# First get item metrics for all time series for all frequencies; these are per time series metrics.
# Then we aggregate the metrics by slicing according to the hierarchy.
# `metrics_per_ts` is a dataframe where columns contain all item metrics;
# number of rows = num_levels x num_ts, with the row ordering:
# (ts_1_6M, ts_1_2M, ts_1_1M, ts_2_6M, ts_2_2M, ts_2_1M, ...)
_, metrics_per_ts = evaluator(
ts_iterator=test_ts_at_all_levels_it,
fcst_iterator=forecast_at_all_levels_unpacked_it,
num_levels = len(temporal_hierarchy.agg_multiples)

# In one go, we can get item metrics for time series that have the same frequency.
# So we create `num_levels` copies of the iterator and obtain the item metrics
# for each level independently.
forecast_at_all_levels_unpacked_it_set = tee(
forecast_at_all_levels_unpacked_it,
num_levels,
)
test_ts_at_all_levels_it_set = tee(test_ts_at_all_levels_it, num_levels)

# Since forecasts for all granularities are in the same iterable,
# we need a way to iterate through forecasts skipping some elements.
def skip_iter(it, num_skips: int, offset: int):
for _ in range(offset):
next(it)
for item in it:
for _ in range(num_skips):
next(it, None) # None: in case the `it` is already exhausted.
yield item

num_levels = len(temporal_hierarchy.agg_multiples)
metrics_to_return = {}
for level in range(num_levels):
agg_metrics_level, _ = evaluator.get_aggregate_metrics(
metrics_per_ts.iloc[level:None:num_levels]
agg_metrics_level, metrics_per_ts_level = evaluator(
ts_iterator=skip_iter(
test_ts_at_all_levels_it_set[level],
num_skips=num_levels - 1,
offset=level,
),
fcst_iterator=skip_iter(
forecast_at_all_levels_unpacked_it_set[level],
num_skips=num_levels - 1,
offset=level,
),
)

for metric_name in metrics:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
)


EVALUATE_ALL_LEVELS = True


def main():
dataset = get_dataset("exchange_rate", regenerate=False)

Expand Down

0 comments on commit 7ab786b

Please sign in to comment.