Skip to content

Commit

Permalink
fix: allow transform_inputs to handle cases where index column is set as time_col (#254)
Browse files: browse the repository at this point in the history
  • Loading branch information
Yibei990826 authored Mar 21, 2024
1 parent 6ec87cf commit 77c2fbc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 22 deletions.
60 changes: 42 additions & 18 deletions nbs/timegpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -246,39 +246,38 @@
"\n",
" def transform_inputs(self, df: pd.DataFrame, X_df: pd.DataFrame):\n",
" df = df.copy()\n",
" main_logger.info('Validating inputs...')\n",
" if self.base_freq is None and hasattr(df.index, 'freq'):\n",
" main_logger.info(\"Validating inputs...\")\n",
" if self.base_freq is None and hasattr(df.index, \"freq\"):\n",
" inferred_freq = df.index.freq\n",
" if inferred_freq is not None:\n",
" inferred_freq = inferred_freq.rule_code\n",
" main_logger.info(f'Inferred freq: {inferred_freq}')\n",
" main_logger.info(f\"Inferred freq: {inferred_freq}\")\n",
" self.freq = inferred_freq\n",
" time_col = df.index.name\n",
" if time_col is None:\n",
" time_col = 'ds'\n",
" df.index.name = time_col\n",
" time_col = df.index.name if df.index.name else \"ds\"\n",
" self.time_col = time_col\n",
" df.index.name = time_col\n",
" df = df.reset_index()\n",
" else:\n",
" self.freq = self.base_freq\n",
" renamer = {\n",
" self.id_col: 'unique_id',\n",
" self.time_col: 'ds',\n",
" self.target_col: 'y',\n",
" self.id_col: \"unique_id\",\n",
" self.time_col: \"ds\",\n",
" self.target_col: \"y\",\n",
" }\n",
" df = df.rename(columns=renamer)\n",
" if df.dtypes.ds != 'object':\n",
" df['ds'] = df['ds'].astype(str)\n",
" if 'unique_id' not in df.columns:\n",
" if df.dtypes.ds != \"object\":\n",
" df[\"ds\"] = df[\"ds\"].astype(str)\n",
" if \"unique_id\" not in df.columns:\n",
" # Insert unique_id column\n",
" df = df.assign(unique_id='ts_0')\n",
" df = df.assign(unique_id=\"ts_0\")\n",
" self.drop_uid = True\n",
" if X_df is not None:\n",
" X_df = X_df.copy()\n",
" X_df = X_df.rename(columns=renamer)\n",
" if 'unique_id' not in X_df.columns:\n",
" X_df = X_df.assign(unique_id='ts_0')\n",
" if X_df.dtypes.ds != 'object':\n",
" X_df['ds'] = X_df['ds'].astype(str)\n",
" if \"unique_id\" not in X_df.columns:\n",
" X_df = X_df.assign(unique_id=\"ts_0\")\n",
" if X_df.dtypes.ds != \"object\":\n",
" X_df[\"ds\"] = X_df[\"ds\"].astype(str)\n",
" return df, X_df\n",
"\n",
" def transform_outputs(self, fcst_df: pd.DataFrame, level_to_quantiles: bool = False):\n",
Expand Down Expand Up @@ -2445,6 +2444,31 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test using index as time_col\n",
"# same results\n",
"df_test = df.copy()\n",
"df_test[\"timestamp\"] = pd.to_datetime(df_test[\"timestamp\"])\n",
"df_test.set_index(df_test[\"timestamp\"], inplace=True)\n",
"df_test.drop(columns=\"timestamp\", inplace=True)\n",
"\n",
"# Using user_provided time_col and freq\n",
"timegpt_anomalies_df_1 = timegpt.detect_anomalies(df, time_col='timestamp', target_col='value', freq= 'M')\n",
"# Infer time_col and freq from index\n",
"timegpt_anomalies_df_2 = timegpt.detect_anomalies(df_test, time_col='timestamp', target_col='value')\n",
"\n",
"pd.testing.assert_frame_equal(\n",
" timegpt_anomalies_df_1,\n",
" timegpt_anomalies_df_2 \n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
7 changes: 3 additions & 4 deletions nixtlats/timegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,9 @@ def transform_inputs(self, df: pd.DataFrame, X_df: pd.DataFrame):
inferred_freq = inferred_freq.rule_code
main_logger.info(f"Inferred freq: {inferred_freq}")
self.freq = inferred_freq
time_col = df.index.name
if time_col is None:
time_col = "ds"
df.index.name = time_col
time_col = df.index.name if df.index.name else "ds"
self.time_col = time_col
df.index.name = time_col
df = df.reset_index()
else:
self.freq = self.base_freq
Expand Down

0 comments on commit 77c2fbc

Please sign in to comment.