Skip to content

Commit

Permalink
feat: add optional insample values reconcile method
Browse files Browse the repository at this point in the history
  • Loading branch information
AzulGarza committed Oct 5, 2022
1 parent eabe7cc commit 4d997be
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 84 deletions.
19 changes: 13 additions & 6 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ class HierarchicalReconciliation:
def __init__(self,
reconcilers: List[Callable]):
self.reconcilers = reconcilers
self.insample = any([method.insample for method in reconcilers])

def reconcile(self,
Y_hat_df: pd.DataFrame,
Y_df: pd.DataFrame,
S: pd.DataFrame,
tags: Dict[str, np.ndarray],
Y_df: Optional[pd.DataFrame] = None,
level: Optional[List[int]] = None,
bootstrap: bool = False):
"""Hierarchical Reconciliation Method.
Expand Down Expand Up @@ -87,11 +88,17 @@ def reconcile(self,
# same order of Y_hat_df to prevent errors
S_ = S.loc[uids]
common_vals = dict(
y_insample = Y_df.pivot(columns='ds', values='y').loc[uids].values.astype(np.float32),
S = S_.values.astype(np.float32),
idx_bottom = S_.index.get_indexer(S.columns),
S=S_.values.astype(np.float32),
idx_bottom=S_.index.get_indexer(S.columns),
tags={key: S_.index.get_indexer(val) for key, val in tags.items()}
)
# we need insample values if
# we are using a method that requires them
# or if we are performing bootstrap
if self.insample or bootstrap:
if Y_df is None:
raise Exception('you need to pass `Y_df`')
common_vals['y_insample'] = Y_df.pivot(columns='ds', values='y').loc[uids].values.astype(np.float32)
fcsts = Y_hat_df.copy()
for reconcile_fn in self.reconcilers:
reconcile_fn_name = _build_fn_name(reconcile_fn)
Expand All @@ -117,7 +124,7 @@ def reconcile(self,
sigmah = sign * (y_hat_model - sigmah) / z
common_vals['sigmah'] = sigmah
common_vals['level'] = level
if has_fitted or bootstrap:
if (self.insample and has_fitted) or bootstrap:
if model_name in Y_df:
y_hat_insample = Y_df.pivot(columns='ds', values=model_name).loc[uids].values
y_hat_insample = y_hat_insample.astype(np.float32)
Expand Down Expand Up @@ -151,6 +158,6 @@ def reconcile(self,
else:
del common_vals['bootstrap_samples']
del common_vals['bootstrap']
if has_fitted:
if self.insample and has_fitted:
del common_vals['y_hat_insample']
return fcsts
196 changes: 119 additions & 77 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@
" def __init__(self,\n",
" reconcilers: List[Callable]):\n",
" self.reconcilers = reconcilers\n",
" self.insample = any([method.insample for method in reconcilers])\n",
"\n",
" def reconcile(self, \n",
" Y_hat_df: pd.DataFrame,\n",
" Y_df: pd.DataFrame,\n",
" S: pd.DataFrame,\n",
" tags: Dict[str, np.ndarray],\n",
" Y_df: Optional[pd.DataFrame] = None,\n",
" level: Optional[List[int]] = None,\n",
" bootstrap: bool = False):\n",
" \"\"\"Hierarchical Reconciliation Method.\n",
Expand Down Expand Up @@ -147,11 +148,17 @@
" # same order of Y_hat_df to prevent errors\n",
" S_ = S.loc[uids]\n",
" common_vals = dict(\n",
" y_insample = Y_df.pivot(columns='ds', values='y').loc[uids].values.astype(np.float32),\n",
" S = S_.values.astype(np.float32),\n",
" idx_bottom = S_.index.get_indexer(S.columns),\n",
" S=S_.values.astype(np.float32),\n",
" idx_bottom=S_.index.get_indexer(S.columns),\n",
" tags={key: S_.index.get_indexer(val) for key, val in tags.items()}\n",
" )\n",
" # we need insample values if \n",
" # we are using a method that requires them\n",
" # or if we are performing bootstrap\n",
" if self.insample or bootstrap:\n",
" if Y_df is None:\n",
" raise Exception('you need to pass `Y_df`')\n",
" common_vals['y_insample'] = Y_df.pivot(columns='ds', values='y').loc[uids].values.astype(np.float32)\n",
" fcsts = Y_hat_df.copy()\n",
" for reconcile_fn in self.reconcilers:\n",
" reconcile_fn_name = _build_fn_name(reconcile_fn)\n",
Expand All @@ -177,7 +184,7 @@
" sigmah = sign * (y_hat_model - sigmah) / z\n",
" common_vals['sigmah'] = sigmah\n",
" common_vals['level'] = level\n",
" if has_fitted or bootstrap:\n",
" if (self.insample and has_fitted) or bootstrap:\n",
" if model_name in Y_df:\n",
" y_hat_insample = Y_df.pivot(columns='ds', values=model_name).loc[uids].values\n",
" y_hat_insample = y_hat_insample.astype(np.float32)\n",
Expand Down Expand Up @@ -211,7 +218,7 @@
" else:\n",
" del common_vals['bootstrap_samples']\n",
" del common_vals['bootstrap']\n",
" if has_fitted:\n",
" if self.insample and has_fitted:\n",
" del common_vals['y_hat_insample']\n",
" return fcsts"
]
Expand Down Expand Up @@ -245,72 +252,6 @@
" name='reconcile', title_level=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# <span style=\"color:DarkBlue\"> Example </span>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from statsforecast.core import StatsForecast\n",
"from statsforecast.models import ETS, Naive\n",
"\n",
"from hierarchicalforecast.utils import aggregate\n",
"from hierarchicalforecast.core import HierarchicalReconciliation\n",
"from hierarchicalforecast.methods import BottomUp, MinTrace\n",
"\n",
"# Load TourismSmall dataset\n",
"df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/tourism.csv')\n",
"df = df.rename({'Trips': 'y', 'Quarter': 'ds'}, axis=1)\n",
"df.insert(0, 'Country', 'Australia')\n",
"\n",
"# Create hierarchical series based on geographic levels and purpose\n",
"# And Convert quarterly ds string to pd.datetime format\n",
"hierarchy_levels = [['Country'],\n",
" ['Country', 'State'], \n",
" ['Country', 'Purpose'], \n",
" ['Country', 'State', 'Region'], \n",
" ['Country', 'State', 'Purpose'], \n",
" ['Country', 'State', 'Region', 'Purpose']]\n",
"\n",
"Y_df, S, tags = aggregate(df=df, spec=hierarchy_levels)\n",
"qs = Y_df['ds'].str.replace(r'(\\d+) (Q\\d)', r'\\1-\\2', regex=True)\n",
"Y_df['ds'] = pd.PeriodIndex(qs, freq='Q').to_timestamp()\n",
"Y_df = Y_df.reset_index()\n",
"\n",
"# Split train/test sets\n",
"Y_test_df = Y_df.groupby('unique_id').tail(4)\n",
"Y_train_df = Y_df.drop(Y_test_df.index)\n",
"\n",
"# Compute base auto-ETS predictions\n",
"# Be careful to identify the correct data frequency; this data is quarterly ('Q')\n",
"fcst = StatsForecast(df=Y_train_df,\n",
" #models=[ETS(season_length=12), Naive()],\n",
" models=[Naive()],\n",
" freq='Q', n_jobs=-1) \n",
"Y_hat_df = fcst.forecast(h=4)\n",
"\n",
"# Reconcile the base predictions\n",
"Y_train_df = Y_train_df.reset_index().set_index('unique_id')\n",
"Y_hat_df = Y_hat_df.reset_index().set_index('unique_id')\n",
"reconcilers = [BottomUp(),\n",
" MinTrace(method='ols')]\n",
"hrec = HierarchicalReconciliation(reconcilers=reconcilers)\n",
"Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, Y_df=Y_train_df,\n",
" S=S, tags=tags)\n",
"Y_rec_df.groupby('unique_id').head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -382,7 +323,32 @@
" # ERM recovers but needs bigger eps\n",
" #ERM(method='reg_bu', lambda_reg=None),\n",
"])\n",
"reconciled = hrec.reconcile(hier_grouped_df_h, hier_grouped_df, S_grouped, tags_grouped)\n",
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_df_h, Y_df=hier_grouped_df, \n",
" S=S_grouped, tags=tags_grouped)\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
" if 'ERM' in model:\n",
" eps = 3\n",
" else:\n",
" eps = 1e-5\n",
" test_close(reconciled['y'], reconciled[model], eps=eps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test reconcile method without insample\n",
"hrec = HierarchicalReconciliation(reconcilers=[\n",
" #these methods should reconstruct the original y\n",
" BottomUp(),\n",
" MinTrace(method='ols'),\n",
" MinTrace(method='wls_struct'),\n",
"])\n",
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_df_h,\n",
" S=S_grouped, tags=tags_grouped)\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
" if 'ERM' in model:\n",
" eps = 3\n",
Expand All @@ -404,7 +370,7 @@
"test_fail(\n",
" hrec.reconcile,\n",
" contains='requires strictly hierarchical structures',\n",
" args=(hier_grouped_df_h, hier_grouped_df, S_grouped, tags_grouped)\n",
" args=(hier_grouped_df_h, S_grouped, tags_grouped, hier_grouped_df,)\n",
")"
]
},
Expand Down Expand Up @@ -448,7 +414,12 @@
" # ERM recovers but needs bigger eps\n",
" #ERM(method='reg_bu', lambda_reg=None),\n",
"])\n",
"reconciled = hrec.reconcile(hier_strict_df_h, hier_strict_df, S_strict, tags_strict)\n",
"reconciled = hrec.reconcile(\n",
" Y_hat_df=hier_strict_df_h, \n",
" Y_df=hier_strict_df, \n",
" S=S_strict, \n",
" tags=tags_strict\n",
")\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
" if 'ERM' in model:\n",
" eps = 3\n",
Expand Down Expand Up @@ -495,7 +466,12 @@
"#even if their signature includes\n",
"#that argument\n",
"hrec = HierarchicalReconciliation([MinTrace(method='ols')])\n",
"reconciled = hrec.reconcile(hier_grouped_df_h, hier_grouped_df.drop(columns=['y_model']), S_grouped, tags_grouped)\n",
"reconciled = hrec.reconcile(\n",
" Y_hat_df=hier_grouped_df_h, \n",
" Y_df=hier_grouped_df.drop(columns=['y_model']), \n",
" S=S_grouped, \n",
" tags=tags_grouped\n",
")\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
" test_close(reconciled['y'], reconciled[model])"
]
Expand All @@ -521,14 +497,80 @@
"#intervals\n",
"hrec = HierarchicalReconciliation([BottomUp()])\n",
"reconciled = hrec.reconcile(hier_grouped_df_h, \n",
" hier_grouped_df, S_grouped, tags_grouped,\n",
" Y_df=hier_grouped_df, S=S_grouped, tags=tags_grouped,\n",
" level=[80, 90], bootstrap=True)\n",
"total = reconciled.loc[tags_grouped['Country/State/Region/Purpose']].groupby('ds').sum().reset_index()\n",
"pd.testing.assert_frame_equal(\n",
" total[['ds', 'y_model/BottomUp']],\n",
" reconciled.loc['Australia'][['ds', 'y_model/BottomUp']].reset_index(drop=True)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# <span style=\"color:DarkBlue\"> Example </span>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from statsforecast.core import StatsForecast\n",
"from statsforecast.models import ETS, Naive\n",
"\n",
"from hierarchicalforecast.utils import aggregate\n",
"from hierarchicalforecast.core import HierarchicalReconciliation\n",
"from hierarchicalforecast.methods import BottomUp, MinTrace\n",
"\n",
"# Load TourismSmall dataset\n",
"df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/tourism.csv')\n",
"df = df.rename({'Trips': 'y', 'Quarter': 'ds'}, axis=1)\n",
"df.insert(0, 'Country', 'Australia')\n",
"\n",
"# Create hierarchical series based on geographic levels and purpose\n",
"# And Convert quarterly ds string to pd.datetime format\n",
"hierarchy_levels = [['Country'],\n",
" ['Country', 'State'], \n",
" ['Country', 'Purpose'], \n",
" ['Country', 'State', 'Region'], \n",
" ['Country', 'State', 'Purpose'], \n",
" ['Country', 'State', 'Region', 'Purpose']]\n",
"\n",
"Y_df, S, tags = aggregate(df=df, spec=hierarchy_levels)\n",
"qs = Y_df['ds'].str.replace(r'(\\d+) (Q\\d)', r'\\1-\\2', regex=True)\n",
"Y_df['ds'] = pd.PeriodIndex(qs, freq='Q').to_timestamp()\n",
"Y_df = Y_df.reset_index()\n",
"\n",
"# Split train/test sets\n",
"Y_test_df = Y_df.groupby('unique_id').tail(4)\n",
"Y_train_df = Y_df.drop(Y_test_df.index)\n",
"\n",
"# Compute base auto-ETS predictions\n",
"# Be careful to identify the correct data frequency; this data is quarterly ('Q')\n",
"fcst = StatsForecast(df=Y_train_df,\n",
" #models=[ETS(season_length=12), Naive()],\n",
" models=[Naive()],\n",
" freq='Q', n_jobs=-1) \n",
"Y_hat_df = fcst.forecast(h=4)\n",
"\n",
"# Reconcile the base predictions\n",
"Y_train_df = Y_train_df.reset_index().set_index('unique_id')\n",
"Y_hat_df = Y_hat_df.reset_index().set_index('unique_id')\n",
"reconcilers = [BottomUp(),\n",
" MinTrace(method='ols')]\n",
"hrec = HierarchicalReconciliation(reconcilers=reconcilers)\n",
"Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, Y_df=Y_train_df,\n",
" S=S, tags=tags)\n",
"Y_rec_df.groupby('unique_id').head(2)"
]
}
],
"metadata": {
Expand Down
3 changes: 2 additions & 1 deletion nbs/evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@
" # ERM recovers but needs bigger eps\n",
" ERM(method='reg_bu', lambda_reg=None),\n",
"])\n",
"reconciled = hrec.reconcile(hier_grouped_df_h, hier_grouped_df, S_grouped, tags_grouped)"
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_df_h, Y_df=hier_grouped_df, \n",
" S=S_grouped, tags=tags_grouped)"
]
},
{
Expand Down

0 comments on commit 4d997be

Please sign in to comment.