From fe81f52f2c837ebbd52f267009fa861e2398c3d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Fri, 8 Nov 2024 11:21:49 -0600 Subject: [PATCH 1/5] initial work --- nbs/lag_transforms.ipynb | 72 ++++++++++++++++++++++++++++++++++------ 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb index 9caac5ee..d6e8c89a 100644 --- a/nbs/lag_transforms.ipynb +++ b/nbs/lag_transforms.ipynb @@ -113,7 +113,15 @@ " out._core_tfm = transforms[0]._core_tfm.stack(\n", " [tfm._core_tfm for tfm in transforms]\n", " )\n", - " return out" + " return out\n", + "\n", + " @property\n", + " def _lag(self):\n", + " return self._core_tfm.lag\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return -1" ] }, { @@ -125,6 +133,7 @@ "source": [ "#| exporti\n", "class Lag(_BaseLagTransform):\n", + " \n", " def __init__(self, lag: int):\n", " self.lag = lag\n", " self._core_tfm = core_tfms.Lag(lag=lag)\n", @@ -136,7 +145,11 @@ " return f'lag{lag}'\n", "\n", " def __eq__(self, other):\n", - " return isinstance(other, Lag) and self.lag == other.lag" + " return isinstance(other, Lag) and self.lag == other.lag\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return self.lag" ] }, { @@ -160,7 +173,11 @@ " If `None`, will be set to `window_size`.\n", " \"\"\"\n", " self.window_size = window_size\n", - " self.min_samples = min_samples" + " self.min_samples = min_samples\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return self._lag + self.window_size" ] }, { @@ -248,7 +265,11 @@ " \"\"\" \n", " self.season_length = season_length\n", " self.window_size = window_size\n", - " self.min_samples = min_samples" + " self.min_samples = min_samples\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return self._lag + self.window_size * self.season_length" ] }, { @@ -311,7 +332,11 @@ "class _ExpandingBase(_BaseLagTransform):\n", " \"\"\"Expanding statistic\"\"\"\n", " def __init__(self):\n", - " ..." 
+ " ...\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return 1" ] }, { @@ -336,7 +361,11 @@ "\n", "class ExpandingQuantile(_ExpandingBase):\n", " def __init__(self, p: float):\n", - " self.p = p" + " self.p = p\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return -1" ] }, { @@ -378,7 +407,11 @@ " alpha : float\n", " Smoothing factor.\"\"\"\n", " def __init__(self, alpha: float):\n", - " self.alpha = alpha" + " self.alpha = alpha\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return 1" ] }, { @@ -429,8 +462,13 @@ " return self.tfm._get_name(lag + self.n)\n", "\n", " def _set_core_tfm(self, lag: int) -> 'Offset':\n", - " self._core_tfm = clone(self.tfm)._set_core_tfm(lag + self.n)\n", - " return self" + " self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n)\n", + " self._core_tfm = self.tfm._core_tfm\n", + " return self\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return self.tfm.required_samples + self.n" ] }, { @@ -448,6 +486,16 @@ "np.testing.assert_allclose(transformed, expected)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "4566e803-0790-413d-8d0d-935039778515", + "metadata": {}, + "outputs": [], + "source": [ + "offset = Offset(RollingMean(window_size=10), 2)._set_core_tfm(5)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -488,7 +536,11 @@ " return self.operator(self.tfm1.transform(ga), self.tfm2.transform(ga))\n", "\n", " def update(self, ga: CoreGroupedArray) -> np.ndarray:\n", - " return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))" + " return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))\n", + "\n", + " @property\n", + " def required_samples(self):\n", + " return max(self.tfm1.required_samples, self.tfm2.required_samples)" ] }, { From f09e570d3248cc29bbcaa5151155fa7aa8c40c9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Fri, 8 Nov 2024 16:34:54 -0600 Subject: [PATCH 2/5] add core logic and test --- mlforecast/_modidx.py | 20 +++++ mlforecast/core.py | 10 +++ mlforecast/lag_transforms.py | 44 ++++++++++- nbs/core.ipynb | 147 +++++++++++++++++++++++++++++++++-- nbs/lag_transforms.ipynb | 41 ++++------ 5 files changed, 229 insertions(+), 33 deletions(-) diff --git a/mlforecast/_modidx.py b/mlforecast/_modidx.py index ac93701f..29d79bc6 100644 --- a/mlforecast/_modidx.py +++ b/mlforecast/_modidx.py @@ -262,6 +262,8 @@ 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Combine.update': ( 'lag_transforms.html#combine.update', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms.Combine.update_samples': ( 'lag_transforms.html#combine.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.ExpandingMax': ( 'lag_transforms.html#expandingmax', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.ExpandingMean': ( 'lag_transforms.html#expandingmean', @@ -272,12 +274,16 @@ 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.ExpandingQuantile.__init__': ( 'lag_transforms.html#expandingquantile.__init__', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms.ExpandingQuantile.update_samples': ( 'lag_transforms.html#expandingquantile.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.ExpandingStd': ( 'lag_transforms.html#expandingstd', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.ExponentiallyWeightedMean': ( 'lag_transforms.html#exponentiallyweightedmean', 'mlforecast/lag_transforms.py'), 
'mlforecast.lag_transforms.ExponentiallyWeightedMean.__init__': ( 'lag_transforms.html#exponentiallyweightedmean.__init__', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms.ExponentiallyWeightedMean.update_samples': ( 'lag_transforms.html#exponentiallyweightedmean.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Lag': ('lag_transforms.html#lag', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Lag.__eq__': ( 'lag_transforms.html#lag.__eq__', 'mlforecast/lag_transforms.py'), @@ -287,6 +293,8 @@ 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Lag._set_core_tfm': ( 'lag_transforms.html#lag._set_core_tfm', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms.Lag.update_samples': ( 'lag_transforms.html#lag.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Offset': ( 'lag_transforms.html#offset', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Offset.__init__': ( 'lag_transforms.html#offset.__init__', @@ -295,6 +303,8 @@ 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.Offset._set_core_tfm': ( 'lag_transforms.html#offset._set_core_tfm', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms.Offset.update_samples': ( 'lag_transforms.html#offset.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.RollingMax': ( 'lag_transforms.html#rollingmax', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms.RollingMean': ( 'lag_transforms.html#rollingmean', @@ -327,6 +337,8 @@ 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._BaseLagTransform._get_name': ( 'lag_transforms.html#_baselagtransform._get_name', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms._BaseLagTransform._lag': ( 'lag_transforms.html#_baselagtransform._lag', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._BaseLagTransform._set_core_tfm': ( 'lag_transforms.html#_baselagtransform._set_core_tfm', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._BaseLagTransform.stack': ( 'lag_transforms.html#_baselagtransform.stack', @@ -337,18 +349,26 @@ 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._BaseLagTransform.update': ( 'lag_transforms.html#_baselagtransform.update', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms._BaseLagTransform.update_samples': ( 'lag_transforms.html#_baselagtransform.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._ExpandingBase': ( 'lag_transforms.html#_expandingbase', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._ExpandingBase.__init__': ( 'lag_transforms.html#_expandingbase.__init__', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms._ExpandingBase.update_samples': ( 'lag_transforms.html#_expandingbase.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._RollingBase': ( 'lag_transforms.html#_rollingbase', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._RollingBase.__init__': ( 'lag_transforms.html#_rollingbase.__init__', 'mlforecast/lag_transforms.py'), + 'mlforecast.lag_transforms._RollingBase.update_samples': ( 'lag_transforms.html#_rollingbase.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._Seasonal_RollingBase': ( 'lag_transforms.html#_seasonal_rollingbase', 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._Seasonal_RollingBase.__init__': ( 'lag_transforms.html#_seasonal_rollingbase.__init__', 'mlforecast/lag_transforms.py'), + 
'mlforecast.lag_transforms._Seasonal_RollingBase.update_samples': ( 'lag_transforms.html#_seasonal_rollingbase.update_samples', + 'mlforecast/lag_transforms.py'), 'mlforecast.lag_transforms._pascal2camel': ( 'lag_transforms.html#_pascal2camel', 'mlforecast/lag_transforms.py')}, 'mlforecast.lgb_cv': { 'mlforecast.lgb_cv.LightGBMCV': ('lgb_cv.html#lightgbmcv', 'mlforecast/lgb_cv.py'), diff --git a/mlforecast/core.py b/mlforecast/core.py index 6726d1bb..2096e499 100644 --- a/mlforecast/core.py +++ b/mlforecast/core.py @@ -432,6 +432,16 @@ def _transform( self._dropped_series = None # once we've computed the features and target we can slice the series + update_samples = [ + getattr(tfm, "update_samples", -1) for tfm in self.transforms.values() + ] + if ( + self.keep_last_n is None + and update_samples + and all(samples > -1 for samples in update_samples) + ): + # user didn't set keep_last_n and we can infer it from the transforms + self.keep_last_n = max(update_samples) if self.keep_last_n is not None: self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None)) del self._restore_idxs, self._sort_idxs diff --git a/mlforecast/lag_transforms.py b/mlforecast/lag_transforms.py index e1b7cb8b..72416e29 100644 --- a/mlforecast/lag_transforms.py +++ b/mlforecast/lag_transforms.py @@ -68,8 +68,17 @@ def stack(transforms: Sequence["_BaseLagTransform"]) -> "_BaseLagTransform": ) return out + @property + def _lag(self): + return self._core_tfm.lag - 1 + + @property + def update_samples(self) -> int: + return -1 + # %% ../nbs/lag_transforms.ipynb 6 class Lag(_BaseLagTransform): + def __init__(self, lag: int): self.lag = lag self._core_tfm = core_tfms.Lag(lag=lag) @@ -83,6 +92,10 @@ def _get_name(self, lag: int) -> str: def __eq__(self, other): return isinstance(other, Lag) and self.lag == other.lag + @property + def update_samples(self) -> int: + return self._lag + # %% ../nbs/lag_transforms.ipynb 7 class _RollingBase(_BaseLagTransform): "Rolling statistic" @@ -100,6 +113,10 @@ def __init__(self, window_size: int, min_samples: Optional[int] = None): self.window_size = window_size self.min_samples = min_samples + @property + def update_samples(self) -> int: + return self._lag + self.window_size + # %% ../nbs/lag_transforms.ipynb 8 class RollingMean(_RollingBase): ... @@ -149,6 +166,10 @@ def __init__( self.window_size = window_size self.min_samples = min_samples + @property + def update_samples(self) -> int: + return self._lag + self.season_length * self.window_size + # %% ../nbs/lag_transforms.ipynb 11 class SeasonalRollingMean(_Seasonal_RollingBase): ... @@ -183,6 +204,10 @@ class _ExpandingBase(_BaseLagTransform): def __init__(self): ... + @property + def update_samples(self) -> int: + return 1 + # %% ../nbs/lag_transforms.ipynb 14 class ExpandingMean(_ExpandingBase): ... 
@@ -200,6 +225,10 @@ class ExpandingQuantile(_ExpandingBase): def __init__(self, p: float): self.p = p + @property + def update_samples(self) -> int: + return -1 + # %% ../nbs/lag_transforms.ipynb 16 class ExponentiallyWeightedMean(_BaseLagTransform): """Exponentially weighted average @@ -212,6 +241,10 @@ class ExponentiallyWeightedMean(_BaseLagTransform): def __init__(self, alpha: float): self.alpha = alpha + @property + def update_samples(self) -> int: + return 1 + # %% ../nbs/lag_transforms.ipynb 18 class Offset(_BaseLagTransform): """Shift series before computing transformation @@ -231,9 +264,14 @@ def _get_name(self, lag: int) -> str: return self.tfm._get_name(lag + self.n) def _set_core_tfm(self, lag: int) -> "Offset": - self._core_tfm = clone(self.tfm)._set_core_tfm(lag + self.n) + self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n) + self._core_tfm = self.tfm._core_tfm return self + @property + def update_samples(self) -> int: + return self.tfm.update_samples + self.n + # %% ../nbs/lag_transforms.ipynb 20 class Combine(_BaseLagTransform): """Combine two lag transformations using an operator @@ -269,3 +307,7 @@ def transform(self, ga: CoreGroupedArray) -> np.ndarray: def update(self, ga: CoreGroupedArray) -> np.ndarray: return self.operator(self.tfm1.update(ga), self.tfm2.update(ga)) + + @property + def update_samples(self): + return max(self.tfm1.update_samples, self.tfm2.update_samples) diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 1cb2fca6..0d5451d2 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -20,7 +20,16 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "#|hide\n", "%load_ext autoreload\n", @@ -916,8 +925,18 @@ " self._dropped_series = None\n", "\n", " # once we've computed the features and target we can slice the series\n", + " update_samples = [\n", + " getattr(tfm, 'update_samples', -1) for tfm in self.transforms.values()\n", + " ]\n", + " if (\n", + " self.keep_last_n is None\n", + " and update_samples\n", + " and all(samples > -1 for samples in update_samples)\n", + " ):\n", + " # user didn't set keep_last_n and we can infer it from the transforms\n", + " self.keep_last_n = max(update_samples)\n", " if self.keep_last_n is not None:\n", - " self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None)) \n", + " self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None))\n", " del self._restore_idxs, self._sort_idxs\n", "\n", " # lag transforms\n", @@ -1671,7 +1690,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.fit_transform\n", "\n", @@ -1692,7 +1711,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.fit_transform\n", "\n", @@ -1978,7 +1997,43 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - 
"outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "## TimeSeries.predict\n", + "\n", + "> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n", + "> [sklearn.base.BaseEstimator]]], horizon:int,\n", + "> before_predict_callback:Optional[Callable]=None,\n", + "> after_predict_callback:Optional[Callable]=None,\n", + "> X_df:Optional[~DFType]=None,\n", + "> ids:Optional[List[str]]=None)" + ], + "text/plain": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "## TimeSeries.predict\n", + "\n", + "> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n", + "> [sklearn.base.BaseEstimator]]], horizon:int,\n", + "> before_predict_callback:Optional[Callable]=None,\n", + "> after_predict_callback:Optional[Callable]=None,\n", + "> X_df:Optional[~DFType]=None,\n", + "> ids:Optional[List[str]]=None)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "show_doc(TimeSeries.predict, title_level=2)" ] @@ -2094,7 +2149,41 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "## TimeSeries.update\n", + "\n", + "> TimeSeries.update\n", + "> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n", + "> .frame.DataFrame])\n", + "\n", + "*Update the values of the stored series.*" + ], + "text/plain": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "## TimeSeries.update\n", + "\n", + "> TimeSeries.update\n", + "> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n", + "> .frame.DataFrame])\n", + "\n", + "*Update the values of the stored series.*" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "show_doc(TimeSeries.update, title_level=2)" ] @@ -2180,7 +2269,15 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "sys:1: CategoricalRemappingWarning: Local categoricals have different encodings, expensive re-encoding is done to perform this merge operation. 
Consider using a StringCache or an Enum type if the categories are known in advance\n" + ] + } + ], "source": [ "#| hide\n", "#| polars\n", @@ -2429,6 +2526,42 @@ "preds2 = ts2.predict({'model': NaiveModel()}, 10)\n", "pd.testing.assert_frame_equal(preds, preds2)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# automatically set keep_last_n for built-in lag transforms\n", + "ts = TimeSeries(\n", + " freq='D',\n", + " lags=[1, 2],\n", + " date_features=['dayofweek'],\n", + " lag_transforms={\n", + " 1: [RollingMean(1), RollingMean(4)],\n", + " },\n", + ")\n", + "ts.fit_transform(series, 'unique_id', 'ds', 'y', keep_last_n=20)\n", + "assert ts.keep_last_n == 20\n", + "ts.fit_transform(series, 'unique_id', 'ds', 'y')\n", + "assert ts.keep_last_n == 4\n", + "# we can't infer it for functions\n", + "ts = TimeSeries(\n", + " freq='D',\n", + " lags=[1, 2],\n", + " date_features=['dayofweek'],\n", + " lag_transforms={\n", + " 1: [RollingMean(1), RollingMean(4)],\n", + " 5: [expanding_mean],\n", + " },\n", + ")\n", + "ts.fit_transform(series, 'unique_id', 'ds', 'y', keep_last_n=20)\n", + "assert ts.keep_last_n == 20\n", + "ts.fit_transform(series, 'unique_id', 'ds', 'y')\n", + "assert ts.keep_last_n is None" + ] } ], "metadata": { diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb index d6e8c89a..4f96cc81 100644 --- a/nbs/lag_transforms.ipynb +++ b/nbs/lag_transforms.ipynb @@ -117,10 +117,10 @@ "\n", " @property\n", " def _lag(self):\n", - " return self._core_tfm.lag\n", + " return self._core_tfm.lag - 1\n", "\n", " @property\n", - " def required_samples(self):\n", + " def update_samples(self) -> int:\n", " return -1" ] }, @@ -148,8 +148,8 @@ " return isinstance(other, Lag) and self.lag == other.lag\n", "\n", " @property\n", - " def required_samples(self):\n", - " return self.lag" + " def update_samples(self) -> int:\n", + " return self._lag" ] }, { @@ -176,7 +176,7 @@ " self.min_samples = min_samples\n", "\n", " @property\n", - " def required_samples(self):\n", + " def update_samples(self) -> int:\n", " return self._lag + self.window_size" ] }, @@ -268,8 +268,8 @@ " self.min_samples = min_samples\n", "\n", " @property\n", - " def required_samples(self):\n", - " return self._lag + self.window_size * self.season_length" + " def update_samples(self) -> int:\n", + " return self._lag + self.season_length * self.window_size" ] }, { @@ -335,7 +335,7 @@ " ...\n", "\n", " @property\n", - " def required_samples(self):\n", + " def update_samples(self) -> int:\n", " return 1" ] }, @@ -364,7 +364,7 @@ " self.p = p\n", "\n", " @property\n", - " def required_samples(self):\n", + " def update_samples(self) -> int:\n", " return -1" ] }, @@ -410,7 +410,7 @@ " self.alpha = alpha\n", "\n", " @property\n", - " def required_samples(self):\n", + " def update_samples(self) -> int:\n", " return 1" ] }, @@ -467,8 +467,8 @@ " return self\n", "\n", " @property\n", - " def required_samples(self):\n", - " return self.tfm.required_samples + self.n" + " def update_samples(self) -> int:\n", + " return self.tfm.update_samples + self.n" ] }, { @@ -486,16 +486,6 @@ "np.testing.assert_allclose(transformed, expected)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "4566e803-0790-413d-8d0d-935039778515", - "metadata": {}, - "outputs": [], - "source": [ - "offset = Offset(RollingMean(window_size=10), 2)._set_core_tfm(5)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -539,8 +529,8 @@ " return 
self.operator(self.tfm1.update(ga), self.tfm2.update(ga))\n", "\n", " @property\n", - " def required_samples(self):\n", - " return max(self.tfm1.required_samples, self.tfm2.required_samples)" + " def update_samples(self):\n", + " return max(self.tfm1.update_samples, self.tfm2.update_samples)" ] }, { @@ -603,7 +593,8 @@ " tfm._set_core_tfm(1)\n", " tfm._get_name(1)\n", " tfm.transform(ga)\n", - " tfm.update(ga)" + " tfm.update(ga)\n", + " tfm.update_samples" ] } ], From 87e7ca1ab2501220b59cc07d54720e0e25995bb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 11 Nov 2024 11:56:43 -0600 Subject: [PATCH 3/5] fix lag samples --- mlforecast/lag_transforms.py | 2 +- nbs/forecast.ipynb | 138 ++++++++++++++++------------------- nbs/lag_transforms.ipynb | 2 +- 3 files changed, 63 insertions(+), 79 deletions(-) diff --git a/mlforecast/lag_transforms.py b/mlforecast/lag_transforms.py index 72416e29..63cc7fb1 100644 --- a/mlforecast/lag_transforms.py +++ b/mlforecast/lag_transforms.py @@ -94,7 +94,7 @@ def __eq__(self, other): @property def update_samples(self) -> int: - return self._lag + return self.lag # %% ../nbs/lag_transforms.ipynb 7 class _RollingBase(_BaseLagTransform): diff --git a/nbs/forecast.ipynb b/nbs/forecast.ipynb index d9fc0699..56afd42e 100644 --- a/nbs/forecast.ipynb +++ b/nbs/forecast.ipynb @@ -1306,7 +1306,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L127){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast\n", "\n", @@ -1337,7 +1337,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L127){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast\n", "\n", @@ -1433,7 +1433,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L447){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L446){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.fit\n", "\n", @@ -1467,7 +1467,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L447){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L446){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.fit\n", "\n", @@ -1548,7 +1548,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L966){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L965){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.save\n", "\n", @@ -1564,7 +1564,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L966){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + 
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L965){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.save\n", "\n", @@ -1598,7 +1598,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L983){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L982){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.load\n", "\n", @@ -1614,7 +1614,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L983){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L982){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.load\n", "\n", @@ -1648,7 +1648,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1006){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1005){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.update\n", "\n", @@ -1666,7 +1666,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1006){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1005){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.update\n", "\n", @@ -1702,7 +1702,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L576){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L575){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.make_future_dataframe\n", "\n", @@ -1718,7 +1718,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L576){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L575){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.make_future_dataframe\n", "\n", @@ -1826,40 +1826,36 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L600){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L599){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.get_missing_future\n", "\n", - "> MLForecast.get_missing_future (h:int,\n", - "> X_df:Union[pandas.core.frame.DataFrame,pol\n", - "> ars.dataframe.frame.DataFrame])\n", + "> MLForecast.get_missing_future (h:int, X_df:~DFType)\n", "\n", "*Get the missing id and time combinations in `X_df`.*\n", "\n", "| | **Type** | **Details** |\n", "| -- | -------- | ----------- |\n", "| h | int | Number of periods to predict. |\n", - "| X_df | Union | Dataframe with the future exogenous features. Should have the id column and the time column. 
|\n", - "| **Returns** | **Union** | **DataFrame with expected ids and future times missing in `X_df`** |" + "| X_df | DFType | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", + "| **Returns** | **DFType** | **DataFrame with expected ids and future times missing in `X_df`** |" ], "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L600){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L599){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.get_missing_future\n", "\n", - "> MLForecast.get_missing_future (h:int,\n", - "> X_df:Union[pandas.core.frame.DataFrame,pol\n", - "> ars.dataframe.frame.DataFrame])\n", + "> MLForecast.get_missing_future (h:int, X_df:~DFType)\n", "\n", "*Get the missing id and time combinations in `X_df`.*\n", "\n", "| | **Type** | **Details** |\n", "| -- | -------- | ----------- |\n", "| h | int | Number of periods to predict. |\n", - "| X_df | Union | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", - "| **Returns** | **Union** | **DataFrame with expected ids and future times missing in `X_df`** |" + "| X_df | DFType | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", + "| **Returns** | **DFType** | **DataFrame with expected ids and future times missing in `X_df`** |" ] }, "execution_count": null, @@ -1896,7 +1892,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L548){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L547){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.forecast_fitted_values\n", "\n", @@ -1914,7 +1910,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L548){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L547){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.forecast_fitted_values\n", "\n", @@ -2337,18 +2333,16 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L619){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L618){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.predict\n", "\n", "> MLForecast.predict (h:int,\n", "> before_predict_callback:Optional[Callable]=None,\n", - "> after_predict_callback:Optional[Callable]=None, new_d\n", - "> f:Union[pandas.core.frame.DataFrame,polars.dataframe.\n", - "> frame.DataFrame,NoneType]=None,\n", - "> level:Optional[List[Union[int,float]]]=None, X_df:Uni\n", - "> on[pandas.core.frame.DataFrame,polars.dataframe.frame\n", - "> .DataFrame,NoneType]=None,\n", + "> after_predict_callback:Optional[Callable]=None,\n", + "> new_df:Optional[~DFType]=None,\n", + "> level:Optional[List[Union[int,float]]]=None,\n", + "> X_df:Optional[~DFType]=None,\n", "> ids:Optional[List[str]]=None)\n", "\n", "*Compute the predictions for the next `h` steps.*\n", @@ -2358,27 +2352,25 @@ "| h | int | 
| Number of periods to predict. |\n", "| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n", "| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n", - "| new_df | Union | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n", + "| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n", "| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n", - "| X_df | Union | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", + "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", "| ids | Optional | None | List with subset of ids seen during training for which the forecasts should be computed. |\n", - "| **Returns** | **Union** | | **Predictions for each serie and timestep, with one column per model.** |" + "| **Returns** | **DFType** | | **Predictions for each serie and timestep, with one column per model.** |" ], "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L619){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L618){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.predict\n", "\n", "> MLForecast.predict (h:int,\n", "> before_predict_callback:Optional[Callable]=None,\n", - "> after_predict_callback:Optional[Callable]=None, new_d\n", - "> f:Union[pandas.core.frame.DataFrame,polars.dataframe.\n", - "> frame.DataFrame,NoneType]=None,\n", - "> level:Optional[List[Union[int,float]]]=None, X_df:Uni\n", - "> on[pandas.core.frame.DataFrame,polars.dataframe.frame\n", - "> .DataFrame,NoneType]=None,\n", + "> after_predict_callback:Optional[Callable]=None,\n", + "> new_df:Optional[~DFType]=None,\n", + "> level:Optional[List[Union[int,float]]]=None,\n", + "> X_df:Optional[~DFType]=None,\n", "> ids:Optional[List[str]]=None)\n", "\n", "*Compute the predictions for the next `h` steps.*\n", @@ -2388,11 +2380,11 @@ "| h | int | | Number of periods to predict. |\n", "| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n", "| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n", - "| new_df | Union | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n", + "| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n", "| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n", - "| X_df | Union | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", + "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n", "| ids | Optional | None | List with subset of ids seen during training for which the forecasts should be computed. |\n", - "| **Returns** | **Union** | | **Predictions for each serie and timestep, with one column per model.** |" + "| **Returns** | **DFType** | | **Predictions for each serie and timestep, with one column per model.** |" ] }, "execution_count": null, @@ -2947,13 +2939,11 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L206){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L205){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.preprocess\n", "\n", - "> MLForecast.preprocess\n", - "> (df:Union[pandas.core.frame.DataFrame,polars.dataf\n", - "> rame.frame.DataFrame], id_col:str='unique_id',\n", + "> MLForecast.preprocess (df:~DFType, id_col:str='unique_id',\n", "> time_col:str='ds', target_col:str='y',\n", "> static_features:Optional[List[str]]=None,\n", "> dropna:bool=True, keep_last_n:Optional[int]=None,\n", @@ -2964,7 +2954,7 @@ "\n", "| | **Type** | **Default** | **Details** |\n", "| -- | -------- | ----------- | ----------- |\n", - "| df | Union | | Series data in long format. |\n", + "| df | DFType | | Series data in long format. |\n", "| id_col | str | unique_id | Column that identifies each serie. |\n", "| time_col | str | ds | Column that identifies each timestep, its values can be timestamps or integers. |\n", "| target_col | str | y | Column that contains the target. |\n", @@ -2979,13 +2969,11 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L206){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L205){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.preprocess\n", "\n", - "> MLForecast.preprocess\n", - "> (df:Union[pandas.core.frame.DataFrame,polars.dataf\n", - "> rame.frame.DataFrame], id_col:str='unique_id',\n", + "> MLForecast.preprocess (df:~DFType, id_col:str='unique_id',\n", "> time_col:str='ds', target_col:str='y',\n", "> static_features:Optional[List[str]]=None,\n", "> dropna:bool=True, keep_last_n:Optional[int]=None,\n", @@ -2996,7 +2984,7 @@ "\n", "| | **Type** | **Default** | **Details** |\n", "| -- | -------- | ----------- | ----------- |\n", - "| df | Union | | Series data in long format. |\n", + "| df | DFType | | Series data in long format. |\n", "| id_col | str | unique_id | Column that identifies each serie. |\n", "| time_col | str | ds | Column that identifies each timestep, its values can be timestamps or integers. |\n", "| target_col | str | y | Column that contains the target. 
|\n", @@ -3299,7 +3287,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L262){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L261){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.fit_models\n", "\n", @@ -3318,7 +3306,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L262){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L261){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.fit_models\n", "\n", @@ -3443,15 +3431,13 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L767){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L766){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.cross_validation\n", "\n", - "> MLForecast.cross_validation\n", - "> (df:Union[pandas.core.frame.DataFrame,polars\n", - "> .dataframe.frame.DataFrame], n_windows:int,\n", - "> h:int, id_col:str='unique_id',\n", - "> time_col:str='ds', target_col:str='y',\n", + "> MLForecast.cross_validation (df:~DFType, n_windows:int, h:int,\n", + "> id_col:str='unique_id', time_col:str='ds',\n", + "> target_col:str='y',\n", "> step_size:Optional[int]=None,\n", "> static_features:Optional[List[str]]=None,\n", "> dropna:bool=True,\n", @@ -3472,7 +3458,7 @@ "\n", "| | **Type** | **Default** | **Details** |\n", "| -- | -------- | ----------- | ----------- |\n", - "| df | Union | | Series data in long format. |\n", + "| df | DFType | | Series data in long format. |\n", "| n_windows | int | | Number of windows to evaluate. |\n", "| h | int | | Forecast horizon. |\n", "| id_col | str | unique_id | Column that identifies each serie. |\n", @@ -3491,20 +3477,18 @@ "| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n", "| fitted | bool | False | Store the in-sample predictions. |\n", "| as_numpy | bool | False | Cast features to numpy array. 
|\n", - "| **Returns** | **Union** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |" + "| **Returns** | **DFType** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |" ], "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L767){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L766){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.cross_validation\n", "\n", - "> MLForecast.cross_validation\n", - "> (df:Union[pandas.core.frame.DataFrame,polars\n", - "> .dataframe.frame.DataFrame], n_windows:int,\n", - "> h:int, id_col:str='unique_id',\n", - "> time_col:str='ds', target_col:str='y',\n", + "> MLForecast.cross_validation (df:~DFType, n_windows:int, h:int,\n", + "> id_col:str='unique_id', time_col:str='ds',\n", + "> target_col:str='y',\n", "> step_size:Optional[int]=None,\n", "> static_features:Optional[List[str]]=None,\n", "> dropna:bool=True,\n", @@ -3525,7 +3509,7 @@ "\n", "| | **Type** | **Default** | **Details** |\n", "| -- | -------- | ----------- | ----------- |\n", - "| df | Union | | Series data in long format. |\n", + "| df | DFType | | Series data in long format. |\n", "| n_windows | int | | Number of windows to evaluate. |\n", "| h | int | | Forecast horizon. |\n", "| id_col | str | unique_id | Column that identifies each serie. |\n", @@ -3544,7 +3528,7 @@ "| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n", "| fitted | bool | False | Store the in-sample predictions. |\n", "| as_numpy | bool | False | Cast features to numpy array. 
|\n", - "| **Returns** | **Union** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |" + "| **Returns** | **DFType** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |" ] }, "execution_count": null, @@ -4286,6 +4270,7 @@ " h=horizon,\n", " step_size=horizon,\n", " input_size=input_size,\n", + " keep_last_n=input_size,\n", ")\n", "series_lengths = np.diff(fcst.ts.ga.indptr)\n", "unique_lengths = np.unique(series_lengths)\n", @@ -4437,7 +4422,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L192){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L191){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.from_cv\n", "\n", @@ -4446,7 +4431,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L192){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L191){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MLForecast.from_cv\n", "\n", @@ -4480,7 +4465,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "[LightGBM] [Info] Start training from score 0.084340\n", "[10] mape: 0.118569\n", "[20] mape: 0.111506\n", "[30] mape: 0.107314\n", diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb index 4f96cc81..3ff6ac59 100644 --- a/nbs/lag_transforms.ipynb +++ b/nbs/lag_transforms.ipynb @@ -149,7 +149,7 @@ "\n", " @property\n", " def update_samples(self) -> int:\n", - " return self._lag" + " return self.lag" ] }, { From 328c7b520488ce4dda49441fde3f0aa3fc8b98b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 11 Nov 2024 12:31:48 -0600 Subject: [PATCH 4/5] one more test --- nbs/lag_transforms.ipynb | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb index 3ff6ac59..2eb96f58 100644 --- a/nbs/lag_transforms.ipynb +++ b/nbs/lag_transforms.ipynb @@ -541,7 +541,9 @@ "outputs": [], "source": [ "#| hide\n", - "import operator" + "import operator\n", + "\n", + "from mlforecast.grouped_array import GroupedArray as MLGroupedArray" ] }, { @@ -593,8 +595,14 @@ " tfm._set_core_tfm(1)\n", " tfm._get_name(1)\n", " tfm.transform(ga)\n", - " tfm.update(ga)\n", - " tfm.update_samples" + " updates = tfm.update(ga)\n", + " upd_samples = tfm.update_samples\n", + " if upd_samples > -1:\n", + " sliced_ga = MLGroupedArray(ga.data, ga.indptr).take_from_groups(slice(-upd_samples, None))\n", + " ga2 = CoreGroupedArray(sliced_ga.data, sliced_ga.indptr)\n", + " tfm.transform(ga) # to reset state\n", + " updates2 = tfm.update(ga2)\n", + " np.testing.assert_allclose(updates, updates2)" ] } ], From 73675816894c8ad9eb4f498f9a6a25c5af5530f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 11 Nov 2024 12:33:32 -0600 Subject: [PATCH 5/5] enforce positive --- mlforecast/core.py | 2 +- nbs/core.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlforecast/core.py b/mlforecast/core.py index 2096e499..cfe9e6b2 100644 --- a/mlforecast/core.py +++ b/mlforecast/core.py @@ -438,7 +438,7 @@ def _transform( if ( self.keep_last_n is None and 
update_samples - and all(samples > -1 for samples in update_samples) + and all(samples > 0 for samples in update_samples) ): # user didn't set keep_last_n and we can infer it from the transforms self.keep_last_n = max(update_samples) diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 0d5451d2..1fefe533 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -931,7 +931,7 @@ " if (\n", " self.keep_last_n is None\n", " and update_samples\n", - " and all(samples > -1 for samples in update_samples)\n", + " and all(samples > 0 for samples in update_samples)\n", " ):\n", " # user didn't set keep_last_n and we can infer it from the transforms\n", " self.keep_last_n = max(update_samples)\n",
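
Taken together, these patches give every built-in lag transform an `update_samples` property and let `TimeSeries._transform` infer `keep_last_n` from it when the caller leaves it unset: each transform reports how many trailing samples per series its `update` needs, while plain callables and `ExpandingQuantile` report -1 (and, after PATCH 5/5, any non-positive value), which disables the inference. A minimal sketch of the reported values, assuming the classes behave as in the diffs above; `_set_core_tfm` is a private helper, used here only to attach the underlying coreforecast transform at lag 1, the same way the notebook tests do:

    from mlforecast.lag_transforms import ExpandingMean, RollingMean, SeasonalRollingMean

    # attach the core transform at lag 1, so _lag = lag - 1 = 0
    for tfm in [RollingMean(window_size=4),
                SeasonalRollingMean(season_length=7, window_size=2),
                ExpandingMean()]:
        tfm._set_core_tfm(1)
        print(type(tfm).__name__, tfm.update_samples)
    # RollingMean          -> 0 + 4     = 4
    # SeasonalRollingMean  -> 0 + 7 * 2 = 14
    # ExpandingMean        -> 1

With `lags=[1, 2]` and `RollingMean(1)`/`RollingMean(4)` on lag 1 (the test added in PATCH 2/5), the reported values are 1, 2, 1 and 4, so the inferred `keep_last_n` is their maximum, 4, matching the `assert ts.keep_last_n == 4` in that test.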