diff --git a/hierarchicalforecast/core.py b/hierarchicalforecast/core.py index 9f406235..96544a3d 100644 --- a/hierarchicalforecast/core.py +++ b/hierarchicalforecast/core.py @@ -134,9 +134,18 @@ def _prepare_fit(self, # Declare output names drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds'] model_names = Y_hat_df.drop(columns=drop_cols, axis=1).columns.to_list() + + # Ensure numeric columns + if not len(Y_hat_df[model_names].select_dtypes(include='number').columns) == len(Y_hat_df[model_names].columns): + raise Exception('`Y_hat_df`s columns contain non numeric types') + + #Ensure no null values + if Y_hat_df[model_names].isnull().values.any(): + raise Exception('`Y_hat_df` contains null values') + pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)] model_names = [name for name in model_names if name not in pi_model_names] - + # TODO: Complete y_hat_insample protection if intervals_method in ['bootstrap', 'permbu']: if not (set(model_names) <= set(Y_df.columns)): diff --git a/hierarchicalforecast/utils.py b/hierarchicalforecast/utils.py index 971be2fc..ed51090d 100644 --- a/hierarchicalforecast/utils.py +++ b/hierarchicalforecast/utils.py @@ -202,6 +202,11 @@ def aggregate(df: pd.DataFrame, `Y_df, S_df, tags`: tuple with hierarchically structured series `Y_df` ($\mathbf{y}_{[a,b]}$), summing dataframe `S_df`, and hierarchical aggregation indexes `tags`. """ + + #Ensure no null values + if df.isnull().values.any(): + raise Exception('`df` contains null values') + #-------------------------------- Wrangling --------------------------------# # constraints S_df and collapsed Y_bottom_df with 'unique_id' Y_bottom_df, S_df, tags = _to_summing_dataframe(df=df, spec=spec) @@ -242,7 +247,7 @@ def aggregate(df: pd.DataFrame, Y_df = Y_df.set_index('unique_id') return Y_df, S_df, tags -# %% ../nbs/utils.ipynb 15 +# %% ../nbs/utils.ipynb 16 class HierarchicalPlot: """ Hierarchical Plot diff --git a/nbs/core.ipynb b/nbs/core.ipynb index b5c26791..018d466d 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -234,9 +234,18 @@ " # Declare output names\n", " drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds']\n", " model_names = Y_hat_df.drop(columns=drop_cols, axis=1).columns.to_list()\n", + "\n", + " # Ensure numeric columns\n", + " if not len(Y_hat_df[model_names].select_dtypes(include='number').columns) == len(Y_hat_df[model_names].columns):\n", + " raise Exception('`Y_hat_df`s columns contain non numeric types')\n", + " \n", + " #Ensure no null values\n", + " if Y_hat_df[model_names].isnull().values.any():\n", + " raise Exception('`Y_hat_df` contains null values')\n", + " \n", " pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)]\n", " model_names = [name for name in model_names if name not in pi_model_names]\n", - "\n", + " \n", " # TODO: Complete y_hat_insample protection\n", " if intervals_method in ['bootstrap', 'permbu']:\n", " if not (set(model_names) <= set(Y_df.columns)):\n", @@ -585,6 +594,39 @@ " test_close(reconciled['y'], reconciled[model], eps=eps)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test incorrect Y_hat_df datatypes\n", + "hier_grouped_hat_df_nan = hier_grouped_hat_df.copy()\n", + "hier_grouped_hat_df_nan.loc['Australia', 'y_model'] = float('nan')\n", + "test_fail(\n", + " hrec.reconcile,\n", + " contains='null values',\n", + " args=(hier_grouped_hat_df_nan, S_grouped_df, tags_grouped, hier_grouped_df),\n", + ")\n", + "\n", + "hier_grouped_hat_df_none = hier_grouped_hat_df.copy()\n", + "hier_grouped_hat_df_none.loc['Australia', 'y_model'] = None\n", + "test_fail(\n", + " hrec.reconcile,\n", + " contains='null values',\n", + " args=(hier_grouped_hat_df_none, S_grouped_df, tags_grouped, hier_grouped_df),\n", + ")\n", + "\n", + "hier_grouped_hat_df_str = hier_grouped_hat_df.copy()\n", + "hier_grouped_hat_df_str['y_model'] = hier_grouped_hat_df_str['y_model'].astype(str)\n", + "test_fail(\n", + " hrec.reconcile,\n", + " contains='numeric types',\n", + " args=(hier_grouped_hat_df_str, S_grouped_df, tags_grouped, hier_grouped_df),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1006,13 +1048,6 @@ " S=S_df, tags=tags)\n", "Y_rec_df.groupby('unique_id').head(2)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb index 832c6c76..de04f2cb 100644 --- a/nbs/utils.ipynb +++ b/nbs/utils.ipynb @@ -53,7 +53,7 @@ "source": [ "#| hide\n", "from nbdev.showdoc import add_docs, show_doc\n", - "from fastcore.test import test_eq, test_close" + "from fastcore.test import test_eq, test_close, test_fail" ] }, { @@ -297,6 +297,11 @@ " `Y_df, S_df, tags`: tuple with hierarchically structured series `Y_df` ($\\mathbf{y}_{[a,b]}$),\n", " summing dataframe `S_df`, and hierarchical aggregation indexes `tags`.\n", " \"\"\"\n", + " \n", + " #Ensure no null values\n", + " if df.isnull().values.any():\n", + " raise Exception('`df` contains null values')\n", + " \n", " #-------------------------------- Wrangling --------------------------------#\n", " # constraints S_df and collapsed Y_bottom_df with 'unique_id'\n", " Y_bottom_df, S_df, tags = _to_summing_dataframe(df=df, spec=spec)\n", @@ -390,6 +395,37 @@ "test_eq(len(tags), len(hiers_grouped))" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e1cb923", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/tourism.csv')\n", + "df = df.rename({'Trips': 'y', 'Quarter': 'ds'}, axis=1)\n", + "df.insert(0, 'Country', 'Australia')\n", + "\n", + "#Unit Test NaN Values\n", + "df_nan = df.copy()\n", + "df_nan.loc[0, 'Region'] = float('nan')\n", + "test_fail(\n", + " aggregate,\n", + " contains='null values',\n", + " args=(df_nan, hiers_strictly),\n", + ")\n", + "\n", + "#Unit Test None Values\n", + "df_none = df.copy()\n", + "df_none.loc[0, 'Region'] = None\n", + "test_fail(\n", + " aggregate,\n", + " contains='null values',\n", + " args=(df_none, hiers_strictly),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null,