Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core Refactorization + Reconcilers.fit #128

Merged
merged 8 commits into from
Dec 13, 2022
39 changes: 24 additions & 15 deletions hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,37 @@
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.BottomUp._get_PW_matrices': ( 'methods.html#bottomup._get_pw_matrices',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.BottomUp.reconcile': ( 'methods.html#bottomup.reconcile',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.BottomUp.fit': ( 'methods.html#bottomup.fit',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.BottomUp.fit_predict': ( 'methods.html#bottomup.fit_predict',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.ERM': ('methods.html#erm', 'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.ERM.__init__': ( 'methods.html#erm.__init__',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.ERM._get_PW_matrices': ( 'methods.html#erm._get_pw_matrices',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.ERM.reconcile': ( 'methods.html#erm.reconcile',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.ERM.fit': ( 'methods.html#erm.fit',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.ERM.fit_predict': ( 'methods.html#erm.fit_predict',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MiddleOut': ( 'methods.html#middleout',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MiddleOut.__init__': ( 'methods.html#middleout.__init__',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MiddleOut._get_PW_matrices': ( 'methods.html#middleout._get_pw_matrices',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MiddleOut.reconcile': ( 'methods.html#middleout.reconcile',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MiddleOut.fit_predict': ( 'methods.html#middleout.fit_predict',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTrace': ( 'methods.html#mintrace',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTrace.__init__': ( 'methods.html#mintrace.__init__',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTrace._get_PW_matrices': ( 'methods.html#mintrace._get_pw_matrices',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTrace.reconcile': ( 'methods.html#mintrace.reconcile',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTrace.fit': ( 'methods.html#mintrace.fit',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.MinTrace.fit_predict': ( 'methods.html#mintrace.fit_predict',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.OptimalCombination': ( 'methods.html#optimalcombination',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.OptimalCombination.__init__': ( 'methods.html#optimalcombination.__init__',
Expand All @@ -60,20 +66,20 @@
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.TopDown._get_PW_matrices': ( 'methods.html#topdown._get_pw_matrices',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.TopDown.reconcile': ( 'methods.html#topdown.reconcile',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.TopDown.fit': ( 'methods.html#topdown.fit',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.TopDown.fit_predict': ( 'methods.html#topdown.fit_predict',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods._get_child_nodes': ( 'methods.html#_get_child_nodes',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods._get_sampler': ( 'methods.html#_get_sampler',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods._reconcile': ( 'methods.html#_reconcile',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods._reconcile_fcst_proportions': ( 'methods.html#_reconcile_fcst_proportions',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.cov2corr': ( 'methods.html#cov2corr',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.crossprod': ( 'methods.html#crossprod',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.is_strictly_hierarchical': ( 'methods.html#is_strictly_hierarchical',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.lasso': ( 'methods.html#lasso',
'hierarchicalforecast/methods.py')},
'hierarchicalforecast.probabilistic_methods': { 'hierarchicalforecast.probabilistic_methods.Bootstrap': ( 'probabilistic_methods.html#bootstrap',
Expand Down Expand Up @@ -121,4 +127,7 @@
'hierarchicalforecast.utils._to_summing_matrix': ( 'utils.html#_to_summing_matrix',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.aggregate': ( 'utils.html#aggregate',
'hierarchicalforecast/utils.py')}}}
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.cov2corr': ('utils.html#cov2corr', 'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.is_strictly_hierarchical': ( 'utils.html#is_strictly_hierarchical',
'hierarchicalforecast/utils.py')}}}
81 changes: 27 additions & 54 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@

# %% ../nbs/core.ipynb 3
import re
import gc
from inspect import signature
from scipy.stats import norm
from typing import Callable, Dict, List, Optional

import numpy as np
import pandas as pd

from .probabilistic_methods import Normality, Bootstrap, PERMBU

# %% ../nbs/core.ipynb 5
def _build_fn_name(fn) -> str:
fn_name = type(fn).__name__
Expand All @@ -35,7 +34,7 @@ def _build_fn_name(fn) -> str:
return fn_name

# %% ../nbs/core.ipynb 9
def _reverse_engineer_sigmah(Y_hat_df, y_hat_model, model_name, uids):
def _reverse_engineer_sigmah(Y_hat_df, y_hat, model_name, uids):
"""
This function assumes that the model creates prediction intervals
under a normality assumption with the following equation:
Expand All @@ -59,7 +58,7 @@ def _reverse_engineer_sigmah(Y_hat_df, y_hat_model, model_name, uids):
level_col = float(level_col[0])
z = norm.ppf(0.5 + level_col / 200)
sigmah = Y_hat_df.pivot(columns='ds', values=pi_col).loc[uids].values
sigmah = sign * (sigmah - y_hat_model) / z
sigmah = sign * (sigmah - y_hat) / z

return sigmah

Expand Down Expand Up @@ -168,68 +167,42 @@ def reconcile(self,
has_fitted = 'y_hat_insample' in signature(reconcile_fn).parameters
has_level = 'level' in signature(reconcile_fn).parameters

# TODO: maybe sort in advance by uids and avoid .loc[uids]
# This change affects y_hat_model, y_insample, y_hat_insample, sigmah
# change pivot for df.values and reshapes.
for model_name in model_names:
# Remember: pivot sorts uid
y_hat_model = Y_hat_df.pivot(columns='ds', values=model_name).loc[uids].values
y_hat = Y_hat_df.pivot(columns='ds', values=model_name).loc[uids].values
reconciler_args['y_hat'] = y_hat

# Recover sigmah and add it to reconciler_args
if has_level and level is not None and intervals_method in ['normality', 'permbu']:
sigmah = _reverse_engineer_sigmah(Y_hat_df=Y_hat_df,
y_hat_model=y_hat_model, model_name=model_name, uids=uids)
if (self.insample and has_fitted) or intervals_method in ['bootstrap', 'permbu']:
y_hat_insample = Y_df.pivot(columns='ds', values=model_name).loc[uids].values
y_hat_insample = y_hat_insample.astype(np.float32)
reconciler_args['y_hat_insample'] = y_hat_insample

if has_level and (level is not None):
reconciler_args['level'] = level
if intervals_method == 'permbu':
y_hat_insample = Y_df.pivot(columns='ds', values=model_name).loc[uids].values
y_hat_insample = y_hat_insample.astype(np.float32)
reconciler_args['sampler'] = PERMBU(
S=reconciler_args['S'],
y_hat=y_hat_model,
tags=reconciler_args['tags'],
y_insample=reconciler_args['y_insample'],
y_hat_insample=y_hat_insample,
sigmah=sigmah,
num_samples=None
)
elif intervals_method == 'normality':
reconciler_args['sampler'] = Normality(S=reconciler_args['S'], sigmah=sigmah)
if (self.insample and has_fitted) or (intervals_method in ['bootstrap']):
if model_name in Y_df:
y_hat_insample = Y_df.pivot(columns='ds', values=model_name).loc[uids].values
y_hat_insample = y_hat_insample.astype(np.float32)
if has_fitted:
reconciler_args['y_hat_insample'] = y_hat_insample
if intervals_method == 'bootstrap' and has_level:
reconciler_args['sampler'] = Bootstrap(
S=reconciler_args['S'],
y_hat=y_hat_model,
y_insample=reconciler_args['y_insample'],
y_hat_insample=y_hat_insample,
num_samples=1_000
)
reconciler_args['level'] = level
else:
# some methods have the residuals argument
# but they don't need them
# e.g. MinTrace(method='ols')
reconciler_args['y_hat_insample'] = None

# Mean reconciliation

if intervals_method in ['normality', 'permbu']:
sigmah = _reverse_engineer_sigmah(Y_hat_df=Y_hat_df,
y_hat=y_hat, model_name=model_name, uids=uids)
reconciler_args['sigmah'] = sigmah

reconciler_args['intervals_method'] = intervals_method

# Mean and Probabilistic reconciliation
kwargs = [key for key in signature(reconcile_fn).parameters if key in reconciler_args.keys()]
kwargs = {key: reconciler_args[key] for key in kwargs}
fcsts_model = reconcile_fn(y_hat=y_hat_model, **kwargs)

# TODO: instantiate prob reconcilers after mean reconc
# and use _prob_reconcile function from probabilistic_methods.py
# this will greatly simplify code above, and improve its readability
# this will require outputs of reconcile_fn to include P, W
fcsts_model = reconcile_fn(**kwargs)

# Parse final outputs
fcsts[f'{model_name}/{reconcile_fn_name}'] = fcsts_model['mean'].flatten()
if intervals_method in ['bootstrap', 'normality', 'permbu'] and level is not None:
for lv in level:
fcsts[f'{model_name}/{reconcile_fn_name}-lo-{lv}'] = fcsts_model[f'lo-{lv}'].flatten()
fcsts[f'{model_name}/{reconcile_fn_name}-hi-{lv}'] = fcsts_model[f'hi-{lv}'].flatten()
del reconciler_args['level']
del reconciler_args['sampler']

if self.insample and has_fitted:
del reconciler_args['y_hat_insample']
gc.collect()

return fcsts
Loading