Skip to content

Commit

Permalink
[FEAT] Faster creation of ProbReconciler's ordered levels (#137)
Browse files Browse the repository at this point in the history
* Added MSSE, sCRPS, and EScore placeholder

* Default [N,H,samples] shape for easier future handling

* Default [N,H,samples] shape for easier future handling

* Working energy score

* Improved docstrings

* Added Exception messages and QL unit test

* TourismL end to end experiment placeholder

* Partial y_hat_insample protections

* adding tqdm dependency

* adding tqdm to reconcilers for loop

* adding timer utility

* Informative exception message when `Y_hat_df` and `Y_df` columns do not match

* Cleaning notebooks
  • Loading branch information
kdgutier authored Dec 17, 2022
1 parent 8baabe0 commit 81abd7b
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 22 deletions.
6 changes: 6 additions & 0 deletions hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Bootstrap.get_prediction_levels': ( 'probabilistic_methods.html#bootstrap.get_prediction_levels',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Bootstrap.get_prediction_quantiles': ( 'probabilistic_methods.html#bootstrap.get_prediction_quantiles',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Bootstrap.get_samples': ( 'probabilistic_methods.html#bootstrap.get_samples',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Normality': ( 'probabilistic_methods.html#normality',
Expand All @@ -118,6 +120,8 @@
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Normality.get_prediction_levels': ( 'probabilistic_methods.html#normality.get_prediction_levels',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Normality.get_prediction_quantiles': ( 'probabilistic_methods.html#normality.get_prediction_quantiles',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.Normality.get_samples': ( 'probabilistic_methods.html#normality.get_samples',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.PERMBU': ( 'probabilistic_methods.html#permbu',
Expand All @@ -134,6 +138,8 @@
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.PERMBU.get_prediction_levels': ( 'probabilistic_methods.html#permbu.get_prediction_levels',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.PERMBU.get_prediction_quantiles': ( 'probabilistic_methods.html#permbu.get_prediction_quantiles',
'hierarchicalforecast/probabilistic_methods.py'),
'hierarchicalforecast.probabilistic_methods.PERMBU.get_samples': ( 'probabilistic_methods.html#permbu.get_samples',
'hierarchicalforecast/probabilistic_methods.py')},
'hierarchicalforecast.utils': { 'hierarchicalforecast.utils.CodeTimer': ( 'utils.html#codetimer',
Expand Down
24 changes: 15 additions & 9 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def reconcile(self,
# TODO: Complete y_hat_insample protection
if intervals_method in ['bootstrap', 'permbu']:
if not (set(model_names) <= set(Y_df.columns)):
raise Exception('Check `Y_hat_df`, `Y_df` columns difference')
raise Exception('Check `Y_hat_df`s models are included in `Y_df` columns')

# Same Y_hat_df/S_df/Y_df's unique_id order to prevent errors
S_ = S.loc[uids]
Expand Down Expand Up @@ -206,15 +206,21 @@ def reconcile(self,
# Parse final outputs
fcsts[f'{model_name}/{reconcile_fn_name}'] = fcsts_model['mean'].flatten()
if intervals_method in ['bootstrap', 'normality', 'permbu'] and level is not None:
for lv in level:
fcsts[f'{model_name}/{reconcile_fn_name}-lo-{lv}'] = fcsts_model[f'lo-{lv}'].flatten()
fcsts[f'{model_name}/{reconcile_fn_name}-hi-{lv}'] = fcsts_model[f'hi-{lv}'].flatten()

end = time.time()
self.execution_times[f'{model_name}/{reconcile_fn_name}'] = (end - start)

level.sort()
hi_names = [f'{model_name}/{reconcile_fn_name}-hi-{lv}' for lv in level]
lo_names = [f'{model_name}/{reconcile_fn_name}-lo-{lv}' for lv in reversed(level)]
sorted_quantiles = np.reshape(fcsts_model['quantiles'], (len(fcsts),-1))
intervals_df = pd.DataFrame(sorted_quantiles,
columns=(lo_names+hi_names), index=fcsts.index)
fcsts = pd.concat([fcsts, intervals_df], axis=1)

del sorted_quantiles
del intervals_df
if self.insample and has_fitted:
del reconciler_args['y_hat_insample']
del y_hat_insample
gc.collect()

end = time.time()
self.execution_times[f'{model_name}/{reconcile_fn_name}'] = (end - start)

return fcsts
7 changes: 6 additions & 1 deletion hierarchicalforecast/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def _reconcile(S: np.ndarray,

# Probabilistic reconciliation
if (level is not None) and (sampler is not None):
res = sampler.get_prediction_levels(res=res, level=level)
# Update results dictionary within
# Vectorized quantiles
quantiles = np.concatenate(
[[(100 - lv) / 200, ((100 - lv) / 200) + lv / 100] for lv in level])
quantiles = np.sort(quantiles)
res = sampler.get_prediction_quantiles(res, quantiles)

return res

Expand Down
26 changes: 26 additions & 0 deletions hierarchicalforecast/probabilistic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ def get_prediction_levels(self, res, level):
res[f'hi-{lv}'] = res['mean'] + zs * self.sigmah_rec
return res

def get_prediction_quantiles(self, res, quantiles):
    """Add Gaussian quantile forecasts to the results dictionary.

    Each requested probability is mapped to its standard-normal z-score and
    the forecast quantile is computed as ``mean + z * sigmah_rec``.

    Parameters
    ----------
    res : dict
        Results dictionary; must already hold the reconciled means under
        ``'mean'``. It is updated in place.
    quantiles : array-like
        Probabilities forwarded to ``norm.ppf``.

    Returns
    -------
    dict
        The same ``res`` object, with ``'sigmah'`` set to ``self.sigmah_rec``
        and ``'quantiles'`` holding the forecasts with the quantile index on
        a new trailing axis.
    """
    # z-score per requested probability -> [Q]
    z_scores = norm.ppf(quantiles)
    res['sigmah'] = self.sigmah_rec

    # Broadcast: [N,H,None] + [N,H,None] * [None,None,Q] -> [N,H,Q]
    # (N series, H horizons per the original shape annotation)
    center = res['mean'][:, :, None]
    spread = self.sigmah_rec[:, :, None] * z_scores[None, None, :]
    res['quantiles'] = center + spread
    return res

# %% ../nbs/probabilistic_methods.ipynb 10
class Bootstrap:
""" Bootstrap Probabilistic Reconciliation Class.
Expand Down Expand Up @@ -187,6 +195,15 @@ def get_prediction_levels(self, res, level):
res[f'hi-{lv}'] = np.quantile(samples, max_q, axis=2)
return res

def get_prediction_quantiles(self, res, quantiles):
    """Add empirical quantile forecasts, computed from samples, to `res`.

    Parameters
    ----------
    res : dict
        Results dictionary, updated in place.
    quantiles : array-like
        Probabilities forwarded to ``np.quantile``.

    Returns
    -------
    dict
        The same ``res`` object with a ``'quantiles'`` entry whose last axis
        indexes the requested quantiles.
    """
    # Draw ``self.num_samples`` samples; the sample index is the third axis
    draws = self.get_samples(num_samples=self.num_samples)

    # np.quantile places the quantile axis first: [Q, N, H]
    quantile_first = np.quantile(draws, quantiles, axis=2)

    # Reorder axes so the quantile index is last: [Q, N, H] -> [N, H, Q]
    res['quantiles'] = np.transpose(quantile_first, (1, 2, 0))
    return res

# %% ../nbs/probabilistic_methods.ipynb 14
class PERMBU:
""" PERMBU Probabilistic Reconciliation Class.
Expand Down Expand Up @@ -400,3 +417,12 @@ def get_prediction_levels(self, res, level):
res[f'lo-{lv}'] = np.quantile(samples, min_q, axis=2)
res[f'hi-{lv}'] = np.quantile(samples, max_q, axis=2)
return res

def get_prediction_quantiles(self, res, quantiles):
    """Add empirical quantile forecasts, computed from samples, to `res`.

    Parameters
    ----------
    res : dict
        Results dictionary, updated in place.
    quantiles : array-like
        Probabilities forwarded to ``np.quantile``.

    Returns
    -------
    dict
        The same ``res`` object with a ``'quantiles'`` entry whose last axis
        indexes the requested quantiles.
    """
    # Samples come from this sampler's own ``get_samples``; axis 2 is the sample index
    sampled = self.get_samples(num_samples=self.num_samples)

    # Quantiles over the sample axis come back quantile-major: [Q, N, H]
    by_quantile = np.quantile(sampled, quantiles, axis=2)

    # Move the quantile axis to the end: [Q, N, H] -> [N, H, Q]
    res['quantiles'] = np.transpose(by_quantile, (1, 2, 0))
    return res
24 changes: 15 additions & 9 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@
" # TODO: Complete y_hat_insample protection\n",
" if intervals_method in ['bootstrap', 'permbu']:\n",
" if not (set(model_names) <= set(Y_df.columns)):\n",
" raise Exception('Check `Y_hat_df`, `Y_df` columns difference')\n",
" raise Exception('Check `Y_hat_df`s models are included in `Y_df` columns')\n",
"\n",
" # Same Y_hat_df/S_df/Y_df's unique_id order to prevent errors\n",
" S_ = S.loc[uids]\n",
Expand Down Expand Up @@ -306,17 +306,23 @@
" # Parse final outputs\n",
" fcsts[f'{model_name}/{reconcile_fn_name}'] = fcsts_model['mean'].flatten()\n",
" if intervals_method in ['bootstrap', 'normality', 'permbu'] and level is not None:\n",
" for lv in level:\n",
" fcsts[f'{model_name}/{reconcile_fn_name}-lo-{lv}'] = fcsts_model[f'lo-{lv}'].flatten()\n",
" fcsts[f'{model_name}/{reconcile_fn_name}-hi-{lv}'] = fcsts_model[f'hi-{lv}'].flatten()\n",
" \n",
" end = time.time()\n",
" self.execution_times[f'{model_name}/{reconcile_fn_name}'] = (end - start)\n",
"\n",
" level.sort()\n",
" hi_names = [f'{model_name}/{reconcile_fn_name}-hi-{lv}' for lv in level]\n",
" lo_names = [f'{model_name}/{reconcile_fn_name}-lo-{lv}' for lv in reversed(level)]\n",
" sorted_quantiles = np.reshape(fcsts_model['quantiles'], (len(fcsts),-1))\n",
" intervals_df = pd.DataFrame(sorted_quantiles, \n",
" columns=(lo_names+hi_names), index=fcsts.index)\n",
" fcsts = pd.concat([fcsts, intervals_df], axis=1)\n",
"\n",
" del sorted_quantiles\n",
" del intervals_df\n",
" if self.insample and has_fitted:\n",
" del reconciler_args['y_hat_insample']\n",
" del y_hat_insample\n",
" gc.collect()\n",
"\n",
" end = time.time()\n",
" self.execution_times[f'{model_name}/{reconcile_fn_name}'] = (end - start)\n",
"\n",
" return fcsts"
]
},
Expand Down
11 changes: 8 additions & 3 deletions nbs/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@
"\n",
" # Probabilistic reconciliation\n",
" if (level is not None) and (sampler is not None):\n",
" res = sampler.get_prediction_levels(res=res, level=level)\n",
" # Update results dictionary within\n",
" # Vectorized quantiles\n",
" quantiles = np.concatenate(\n",
" [[(100 - lv) / 200, ((100 - lv) / 200) + lv / 100] for lv in level])\n",
" quantiles = np.sort(quantiles)\n",
" res = sampler.get_prediction_quantiles(res, quantiles)\n",
"\n",
" return res"
]
Expand Down Expand Up @@ -1673,15 +1678,15 @@
"cls_bottom_up = BottomUp()\n",
"bu_bootstrap_intervals = cls_bottom_up(**reconciler_args)\n",
"test_eq(\n",
" ['mean', 'sigmah', 'lo-80', 'hi-80', 'lo-90', 'hi-90'],\n",
" ['mean', 'sigmah', 'quantiles'],\n",
" list(bu_bootstrap_intervals.keys())\n",
")\n",
"\n",
"# test PERMBU interval's names\n",
"reconciler_args['intervals_method'] = 'permbu'\n",
"bu_permbu_intervals = cls_bottom_up(**reconciler_args)\n",
"test_eq(\n",
" ['mean', 'lo-80', 'hi-80', 'lo-90', 'hi-90'],\n",
" ['mean', 'quantiles'],\n",
" list(bu_permbu_intervals.keys())\n",
")"
]
Expand Down
37 changes: 37 additions & 0 deletions nbs/probabilistic_methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@
" for zs, lv in zip(z, level):\n",
" res[f'lo-{lv}'] = res['mean'] - zs * self.sigmah_rec\n",
" res[f'hi-{lv}'] = res['mean'] + zs * self.sigmah_rec\n",
" return res\n",
"\n",
" def get_prediction_quantiles(self, res, quantiles):\n",
" \"\"\" Adds reconciled forecast quantiles to results dictionary \"\"\"\n",
" # [N,H,None] + [None None,Q] * [N,H,None] -> [N,H,Q]\n",
" z = norm.ppf(quantiles)\n",
" res['sigmah'] = self.sigmah_rec\n",
" res['quantiles'] = res['mean'][:,:,None] + z[None,None,:] * self.sigmah_rec[:,:,None]\n",
" return res"
]
},
Expand Down Expand Up @@ -273,6 +281,15 @@
" max_q = min_q + lv / 100\n",
" res[f'lo-{lv}'] = np.quantile(samples, min_q, axis=2)\n",
" res[f'hi-{lv}'] = np.quantile(samples, max_q, axis=2)\n",
" return res\n",
"\n",
" def get_prediction_quantiles(self, res, quantiles):\n",
" \"\"\" Adds reconciled forecast quantiles to results dictionary \"\"\"\n",
" samples = self.get_samples(num_samples=self.num_samples)\n",
"\n",
" # [Q, N, H] -> [N, H, Q]\n",
" sample_quantiles = np.quantile(samples, quantiles, axis=2)\n",
" res['quantiles'] = sample_quantiles.transpose((1, 2, 0))\n",
" return res"
]
},
Expand Down Expand Up @@ -519,6 +536,15 @@
" max_q = min_q + lv / 100\n",
" res[f'lo-{lv}'] = np.quantile(samples, min_q, axis=2)\n",
" res[f'hi-{lv}'] = np.quantile(samples, max_q, axis=2)\n",
" return res\n",
"\n",
" def get_prediction_quantiles(self, res, quantiles):\n",
" \"\"\" Adds reconciled forecast quantiles to results dictionary \"\"\"\n",
" samples = self.get_samples(num_samples=self.num_samples)\n",
"\n",
" # [Q, N, H] -> [N, H, Q]\n",
" sample_quantiles = np.quantile(samples, quantiles, axis=2)\n",
" res['quantiles'] = sample_quantiles.transpose((1, 2, 0))\n",
" return res"
]
},
Expand Down Expand Up @@ -671,6 +697,17 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"level = [1,2,3,0]\n",
"level.sort()\n",
"level"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit 81abd7b

Please sign in to comment.