Skip to content

Commit

Permalink
FEAT: Add finetune_depth parameter (#471)
Browse files Browse the repository at this point in the history
Co-authored-by: Olivier Sprangers <45119856+elephaint@users.noreply.github.com>
Co-authored-by: Olivier Sprangers <osprangers@gmail.com>
Co-authored-by: José Morales <jmoralz92@gmail.com>
  • Loading branch information
4 people authored Oct 15, 2024
1 parent 5cf879c commit 0359bea
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 35 deletions.
57 changes: 31 additions & 26 deletions nbs/docs/capabilities/forecast/07_finetuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nixtla/nixtla/blob/main/nbs/docs/capabilities/forecast/07_finetuning.ipynb)"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"#| echo: false\n",
"if not IN_COLAB:\n",
Expand Down Expand Up @@ -124,18 +111,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:nixtla.nixtla_client:Validating inputs...\n",
"INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
"INFO:nixtla.nixtla_client:Inferred freq: MS\n",
"INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n"
]
}
],
"outputs": [],
"source": [
"# Read data\n",
"df = pd.read_csv(\"https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv\")\n",
Expand Down Expand Up @@ -166,6 +142,35 @@
"> By default, `timegpt-1` is used. Please see [this tutorial](https://docs.nixtla.io/docs/tutorials-long_horizon_forecasting) on how and when to use `timegpt-1-long-horizon`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, only a small amount of finetuning is applied (`finetune_depth=1`). We can increase the intensity of finetuning by increasing the `finetune_depth` parameter. Note that increasing `finetune_depth` and `finetune_steps` increases wall time for generating predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read data\n",
"df = pd.read_csv(\"https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv\")\n",
"\n",
"# Forecast with fine-tuning.\n",
"# Here, we fine-tune for 5 steps\n",
"# and we finetune more than just the last layer\n",
"forecast_df = nixtla_client.forecast(\n",
" df=df,\n",
" h=12,\n",
" finetune_steps=5,\n",
" finetune_depth=2,\n",
" time_col='timestamp',\n",
" target_col=\"value\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
201 changes: 192 additions & 9 deletions nbs/docs/tutorials/06_finetuning.ipynb

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions nbs/src/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
"#| exporti\n",
"_Loss = Literal[\"default\", \"mae\", \"mse\", \"rmse\", \"mape\", \"smape\"]\n",
"_Model = Literal[\"azureai\", \"timegpt-1\", \"timegpt-1-long-horizon\"]\n",
"_Finetune_Depth = Literal[1, 2, 3, 4, 5]\n",
"\n",
"_date_features_by_freq = {\n",
" # Daily frequencies\n",
Expand Down Expand Up @@ -925,6 +926,7 @@
" level: Optional[List[Union[int, float]]] = None,\n",
" quantiles: Optional[List[float]] = None,\n",
" finetune_steps: NonNegativeInt = 0,\n",
" finetune_depth: _Finetune_Depth = 1,\n",
" finetune_loss: _Loss = 'default',\n",
" clean_ex_first: bool = True,\n",
" validate_api_key: bool = False,\n",
Expand Down Expand Up @@ -976,6 +978,9 @@
" finetune_steps : int (default=0)\n",
" Number of steps used to finetune TimeGPT in the\n",
" new data.\n",
" finetune_depth : int (default=1)\n",
" The depth of the finetuning. Uses a scale from 1 to 5, where 1 means little finetuning,\n",
" and 5 means that the entire model is finetuned.\n",
" finetune_loss : str (default='default')\n",
" Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.\n",
" clean_ex_first : bool (default=True)\n",
Expand Down Expand Up @@ -1025,6 +1030,7 @@
" level=level,\n",
" quantiles=quantiles,\n",
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_api_key=validate_api_key,\n",
Expand Down Expand Up @@ -1112,6 +1118,7 @@
" 'clean_ex_first': clean_ex_first,\n",
" 'level': level,\n",
" 'finetune_steps': finetune_steps,\n",
" 'finetune_depth': finetune_depth,\n",
" 'finetune_loss': finetune_loss,\n",
" 'feature_contributions': feature_contributions and X is not None,\n",
" }\n",
Expand Down Expand Up @@ -1352,6 +1359,7 @@
" n_windows: PositiveInt = 1,\n",
" step_size: Optional[PositiveInt] = None,\n",
" finetune_steps: NonNegativeInt = 0,\n",
" finetune_depth: _Finetune_Depth = 1,\n",
" finetune_loss: str = 'default',\n",
" clean_ex_first: bool = True,\n",
" date_features: Union[bool, List[str]] = False,\n",
Expand Down Expand Up @@ -1404,6 +1412,9 @@
" finetune_steps : int (default=0)\n",
" Number of steps used to finetune TimeGPT in the\n",
" new data.\n",
" finetune_depth : int (default=1)\n",
" The depth of the finetuning. Uses a scale from 1 to 5, where 1 means little finetuning,\n",
" and 5 means that the entire model is finetuned.\n",
" finetune_loss : str (default='default')\n",
" Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.\n",
" clean_ex_first : bool (default=True)\n",
Expand Down Expand Up @@ -1447,6 +1458,7 @@
" step_size=step_size,\n",
" validate_api_key=validate_api_key,\n",
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -1531,6 +1543,7 @@
" 'clean_ex_first': clean_ex_first,\n",
" 'level': level,\n",
" 'finetune_steps': finetune_steps,\n",
" 'finetune_depth': finetune_depth,\n",
" 'finetune_loss': finetune_loss,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
Expand Down Expand Up @@ -2641,6 +2654,7 @@
" level: Optional[List[Union[int, float]]],\n",
" quantiles: Optional[List[float]],\n",
" finetune_steps: NonNegativeInt,\n",
" finetune_depth: _Finetune_Depth,\n",
" finetune_loss: _Loss,\n",
" clean_ex_first: bool,\n",
" validate_api_key: bool,\n",
Expand Down Expand Up @@ -2668,6 +2682,7 @@
" level=level,\n",
" quantiles=quantiles,\n",
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_api_key=validate_api_key,\n",
Expand Down Expand Up @@ -2723,6 +2738,7 @@
" n_windows: PositiveInt,\n",
" step_size: Optional[PositiveInt],\n",
" finetune_steps: NonNegativeInt,\n",
" finetune_depth: _Finetune_Depth,\n",
" finetune_loss: str,\n",
" clean_ex_first: bool,\n",
" date_features: Union[bool, List[str]],\n",
Expand All @@ -2743,6 +2759,7 @@
" n_windows=n_windows,\n",
" step_size=step_size,\n",
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -2829,6 +2846,7 @@
" level: Optional[List[Union[int, float]]],\n",
" quantiles: Optional[List[float]],\n",
" finetune_steps: NonNegativeInt,\n",
" finetune_depth: _Finetune_Depth,\n",
" finetune_loss: _Loss,\n",
" clean_ex_first: bool,\n",
" validate_api_key: bool,\n",
Expand Down Expand Up @@ -2884,6 +2902,7 @@
" level=level,\n",
" quantiles=quantiles,\n",
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_api_key=validate_api_key,\n",
Expand Down Expand Up @@ -2965,6 +2984,7 @@
" n_windows: PositiveInt,\n",
" step_size: Optional[PositiveInt],\n",
" finetune_steps: NonNegativeInt,\n",
" finetune_depth: _Finetune_Depth,\n",
" finetune_loss: _Loss,\n",
" clean_ex_first: bool,\n",
" date_features: Union[bool, List[Union[str, Callable]]],\n",
Expand Down Expand Up @@ -3001,6 +3021,7 @@
" n_windows=n_windows,\n",
" step_size=step_size,\n",
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" date_features=date_features,\n",
Expand Down
21 changes: 21 additions & 0 deletions nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
# %% ../nbs/src/nixtla_client.ipynb 7
_Loss = Literal["default", "mae", "mse", "rmse", "mape", "smape"]
_Model = Literal["azureai", "timegpt-1", "timegpt-1-long-horizon"]
_Finetune_Depth = Literal[1, 2, 3, 4, 5]

_date_features_by_freq = {
# Daily frequencies
Expand Down Expand Up @@ -855,6 +856,7 @@ def forecast(
level: Optional[List[Union[int, float]]] = None,
quantiles: Optional[List[float]] = None,
finetune_steps: NonNegativeInt = 0,
finetune_depth: _Finetune_Depth = 1,
finetune_loss: _Loss = "default",
clean_ex_first: bool = True,
validate_api_key: bool = False,
Expand Down Expand Up @@ -906,6 +908,9 @@ def forecast(
finetune_steps : int (default=0)
Number of steps used to finetune TimeGPT in the
new data.
finetune_depth : int (default=1)
The depth of the finetuning. Uses a scale from 1 to 5, where 1 means little finetuning,
and 5 means that the entire model is finetuned.
finetune_loss : str (default='default')
Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
clean_ex_first : bool (default=True)
Expand Down Expand Up @@ -955,6 +960,7 @@ def forecast(
level=level,
quantiles=quantiles,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first,
validate_api_key=validate_api_key,
Expand Down Expand Up @@ -1044,6 +1050,7 @@ def forecast(
"clean_ex_first": clean_ex_first,
"level": level,
"finetune_steps": finetune_steps,
"finetune_depth": finetune_depth,
"finetune_loss": finetune_loss,
"feature_contributions": feature_contributions and X is not None,
}
Expand Down Expand Up @@ -1290,6 +1297,7 @@ def cross_validation(
n_windows: PositiveInt = 1,
step_size: Optional[PositiveInt] = None,
finetune_steps: NonNegativeInt = 0,
finetune_depth: _Finetune_Depth = 1,
finetune_loss: str = "default",
clean_ex_first: bool = True,
date_features: Union[bool, List[str]] = False,
Expand Down Expand Up @@ -1342,6 +1350,9 @@ def cross_validation(
finetune_steps : int (default=0)
Number of steps used to finetune TimeGPT in the
new data.
finetune_depth : int (default=1)
The depth of the finetuning. Uses a scale from 1 to 5, where 1 means little finetuning,
and 5 means that the entire model is finetuned.
finetune_loss : str (default='default')
Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
clean_ex_first : bool (default=True)
Expand Down Expand Up @@ -1385,6 +1396,7 @@ def cross_validation(
step_size=step_size,
validate_api_key=validate_api_key,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first,
date_features=date_features,
Expand Down Expand Up @@ -1469,6 +1481,7 @@ def cross_validation(
"clean_ex_first": clean_ex_first,
"level": level,
"finetune_steps": finetune_steps,
"finetune_depth": finetune_depth,
"finetune_loss": finetune_loss,
}
with httpx.Client(**self._client_kwargs) as client:
Expand Down Expand Up @@ -1627,6 +1640,7 @@ def _forecast_wrapper(
level: Optional[List[Union[int, float]]],
quantiles: Optional[List[float]],
finetune_steps: NonNegativeInt,
finetune_depth: _Finetune_Depth,
finetune_loss: _Loss,
clean_ex_first: bool,
validate_api_key: bool,
Expand Down Expand Up @@ -1654,6 +1668,7 @@ def _forecast_wrapper(
level=level,
quantiles=quantiles,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first,
validate_api_key=validate_api_key,
Expand Down Expand Up @@ -1711,6 +1726,7 @@ def _cross_validation_wrapper(
n_windows: PositiveInt,
step_size: Optional[PositiveInt],
finetune_steps: NonNegativeInt,
finetune_depth: _Finetune_Depth,
finetune_loss: str,
clean_ex_first: bool,
date_features: Union[bool, List[str]],
Expand All @@ -1731,6 +1747,7 @@ def _cross_validation_wrapper(
n_windows=n_windows,
step_size=step_size,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first,
date_features=date_features,
Expand Down Expand Up @@ -1820,6 +1837,7 @@ def _distributed_forecast(
level: Optional[List[Union[int, float]]],
quantiles: Optional[List[float]],
finetune_steps: NonNegativeInt,
finetune_depth: _Finetune_Depth,
finetune_loss: _Loss,
clean_ex_first: bool,
validate_api_key: bool,
Expand Down Expand Up @@ -1876,6 +1894,7 @@ def format_X_df(
level=level,
quantiles=quantiles,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first,
validate_api_key=validate_api_key,
Expand Down Expand Up @@ -1959,6 +1978,7 @@ def _distributed_cross_validation(
n_windows: PositiveInt,
step_size: Optional[PositiveInt],
finetune_steps: NonNegativeInt,
finetune_depth: _Finetune_Depth,
finetune_loss: _Loss,
clean_ex_first: bool,
date_features: Union[bool, List[Union[str, Callable]]],
Expand Down Expand Up @@ -1995,6 +2015,7 @@ def _distributed_cross_validation(
n_windows=n_windows,
step_size=step_size,
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
clean_ex_first=clean_ex_first,
date_features=date_features,
Expand Down

0 comments on commit 0359bea

Please sign in to comment.