diff --git a/mlforecast/_modidx.py b/mlforecast/_modidx.py
index ac93701f..29d79bc6 100644
--- a/mlforecast/_modidx.py
+++ b/mlforecast/_modidx.py
@@ -262,6 +262,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Combine.update': ( 'lag_transforms.html#combine.update',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms.Combine.update_samples': ( 'lag_transforms.html#combine.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMax': ( 'lag_transforms.html#expandingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMean': ( 'lag_transforms.html#expandingmean',
@@ -272,12 +274,16 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingQuantile.__init__': ( 'lag_transforms.html#expandingquantile.__init__',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms.ExpandingQuantile.update_samples': ( 'lag_transforms.html#expandingquantile.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingStd': ( 'lag_transforms.html#expandingstd',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean': ( 'lag_transforms.html#exponentiallyweightedmean',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean.__init__': ( 'lag_transforms.html#exponentiallyweightedmean.__init__',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms.ExponentiallyWeightedMean.update_samples': ( 'lag_transforms.html#exponentiallyweightedmean.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag': ('lag_transforms.html#lag', 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag.__eq__': ( 'lag_transforms.html#lag.__eq__',
'mlforecast/lag_transforms.py'),
@@ -287,6 +293,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag._set_core_tfm': ( 'lag_transforms.html#lag._set_core_tfm',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms.Lag.update_samples': ( 'lag_transforms.html#lag.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset': ( 'lag_transforms.html#offset',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset.__init__': ( 'lag_transforms.html#offset.__init__',
@@ -295,6 +303,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset._set_core_tfm': ( 'lag_transforms.html#offset._set_core_tfm',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms.Offset.update_samples': ( 'lag_transforms.html#offset.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMax': ( 'lag_transforms.html#rollingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMean': ( 'lag_transforms.html#rollingmean',
@@ -327,6 +337,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform._get_name': ( 'lag_transforms.html#_baselagtransform._get_name',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms._BaseLagTransform._lag': ( 'lag_transforms.html#_baselagtransform._lag',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform._set_core_tfm': ( 'lag_transforms.html#_baselagtransform._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform.stack': ( 'lag_transforms.html#_baselagtransform.stack',
@@ -337,18 +349,26 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform.update': ( 'lag_transforms.html#_baselagtransform.update',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms._BaseLagTransform.update_samples': ( 'lag_transforms.html#_baselagtransform.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._ExpandingBase': ( 'lag_transforms.html#_expandingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._ExpandingBase.__init__': ( 'lag_transforms.html#_expandingbase.__init__',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms._ExpandingBase.update_samples': ( 'lag_transforms.html#_expandingbase.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._RollingBase': ( 'lag_transforms.html#_rollingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._RollingBase.__init__': ( 'lag_transforms.html#_rollingbase.__init__',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms._RollingBase.update_samples': ( 'lag_transforms.html#_rollingbase.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._Seasonal_RollingBase': ( 'lag_transforms.html#_seasonal_rollingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._Seasonal_RollingBase.__init__': ( 'lag_transforms.html#_seasonal_rollingbase.__init__',
'mlforecast/lag_transforms.py'),
+ 'mlforecast.lag_transforms._Seasonal_RollingBase.update_samples': ( 'lag_transforms.html#_seasonal_rollingbase.update_samples',
+ 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._pascal2camel': ( 'lag_transforms.html#_pascal2camel',
'mlforecast/lag_transforms.py')},
'mlforecast.lgb_cv': { 'mlforecast.lgb_cv.LightGBMCV': ('lgb_cv.html#lightgbmcv', 'mlforecast/lgb_cv.py'),
diff --git a/mlforecast/core.py b/mlforecast/core.py
index 6726d1bb..cfe9e6b2 100644
--- a/mlforecast/core.py
+++ b/mlforecast/core.py
@@ -432,6 +432,16 @@ def _transform(
self._dropped_series = None
# once we've computed the features and target we can slice the series
+ update_samples = [
+ getattr(tfm, "update_samples", -1) for tfm in self.transforms.values()
+ ]
+ if (
+ self.keep_last_n is None
+ and update_samples
+ and all(samples > 0 for samples in update_samples)
+ ):
+ # user didn't set keep_last_n and we can infer it from the transforms
+ self.keep_last_n = max(update_samples)
if self.keep_last_n is not None:
self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None))
del self._restore_idxs, self._sort_idxs
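
The effect of the `keep_last_n` inference above, mirrored from the test added to `nbs/core.ipynb` further down: when every configured lag transform exposes a positive `update_samples`, `fit_transform` trims the stored series automatically. A minimal sketch of the intended behavior (not part of the patch; `generate_daily_series` is only used here to build a toy training frame):

    from mlforecast.core import TimeSeries
    from mlforecast.lag_transforms import RollingMean
    from mlforecast.utils import generate_daily_series

    series = generate_daily_series(2)

    # built-in transforms report how much history their updates need,
    # so keep_last_n can be inferred as the max over all of them
    ts = TimeSeries(
        freq='D',
        lags=[1, 2],
        lag_transforms={1: [RollingMean(1), RollingMean(4)]},
    )
    ts.fit_transform(series, 'unique_id', 'ds', 'y')
    assert ts.keep_last_n == 4  # RollingMean(4) at lag 1 needs (1 - 1) + 4 samples

Plain callables (such as the `expanding_mean` function used in the test below) don't expose `update_samples`, so in that case `keep_last_n` stays `None` unless the user passes it explicitly.
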
diff --git a/mlforecast/lag_transforms.py b/mlforecast/lag_transforms.py
index e1b7cb8b..63cc7fb1 100644
--- a/mlforecast/lag_transforms.py
+++ b/mlforecast/lag_transforms.py
@@ -68,8 +68,17 @@ def stack(transforms: Sequence["_BaseLagTransform"]) -> "_BaseLagTransform":
)
return out
+ @property
+ def _lag(self):
+ return self._core_tfm.lag - 1
+
+ @property
+ def update_samples(self) -> int:
+ return -1
+
# %% ../nbs/lag_transforms.ipynb 6
class Lag(_BaseLagTransform):
+
def __init__(self, lag: int):
self.lag = lag
self._core_tfm = core_tfms.Lag(lag=lag)
@@ -83,6 +92,10 @@ def _get_name(self, lag: int) -> str:
def __eq__(self, other):
return isinstance(other, Lag) and self.lag == other.lag
+ @property
+ def update_samples(self) -> int:
+ return self.lag
+
# %% ../nbs/lag_transforms.ipynb 7
class _RollingBase(_BaseLagTransform):
"Rolling statistic"
@@ -100,6 +113,10 @@ def __init__(self, window_size: int, min_samples: Optional[int] = None):
self.window_size = window_size
self.min_samples = min_samples
+ @property
+ def update_samples(self) -> int:
+ return self._lag + self.window_size
+
# %% ../nbs/lag_transforms.ipynb 8
class RollingMean(_RollingBase): ...
@@ -149,6 +166,10 @@ def __init__(
self.window_size = window_size
self.min_samples = min_samples
+ @property
+ def update_samples(self) -> int:
+ return self._lag + self.season_length * self.window_size
+
# %% ../nbs/lag_transforms.ipynb 11
class SeasonalRollingMean(_Seasonal_RollingBase): ...
@@ -183,6 +204,10 @@ class _ExpandingBase(_BaseLagTransform):
def __init__(self): ...
+ @property
+ def update_samples(self) -> int:
+ return 1
+
# %% ../nbs/lag_transforms.ipynb 14
class ExpandingMean(_ExpandingBase): ...
@@ -200,6 +225,10 @@ class ExpandingQuantile(_ExpandingBase):
def __init__(self, p: float):
self.p = p
+ @property
+ def update_samples(self) -> int:
+ return -1
+
# %% ../nbs/lag_transforms.ipynb 16
class ExponentiallyWeightedMean(_BaseLagTransform):
"""Exponentially weighted average
@@ -212,6 +241,10 @@ class ExponentiallyWeightedMean(_BaseLagTransform):
def __init__(self, alpha: float):
self.alpha = alpha
+ @property
+ def update_samples(self) -> int:
+ return 1
+
# %% ../nbs/lag_transforms.ipynb 18
class Offset(_BaseLagTransform):
"""Shift series before computing transformation
@@ -231,9 +264,14 @@ def _get_name(self, lag: int) -> str:
return self.tfm._get_name(lag + self.n)
def _set_core_tfm(self, lag: int) -> "Offset":
- self._core_tfm = clone(self.tfm)._set_core_tfm(lag + self.n)
+ self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n)
+ self._core_tfm = self.tfm._core_tfm
return self
+ @property
+ def update_samples(self) -> int:
+ return self.tfm.update_samples + self.n
+
# %% ../nbs/lag_transforms.ipynb 20
class Combine(_BaseLagTransform):
"""Combine two lag transformations using an operator
@@ -269,3 +307,7 @@ def transform(self, ga: CoreGroupedArray) -> np.ndarray:
def update(self, ga: CoreGroupedArray) -> np.ndarray:
return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))
+
+ @property
+ def update_samples(self):
+ return max(self.tfm1.update_samples, self.tfm2.update_samples)
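
For reference, the values these properties yield once a transform is bound to a lag via `_set_core_tfm`; an illustrative sketch, not part of the patch, with the numbers following directly from the property definitions above:

    from mlforecast.lag_transforms import (
        ExpandingQuantile,
        ExponentiallyWeightedMean,
        RollingMean,
        SeasonalRollingMean,
    )

    rm = RollingMean(window_size=7)
    rm._set_core_tfm(1)                      # lag 1 -> _lag == 0
    assert rm.update_samples == 7            # (1 - 1) + 7

    srm = SeasonalRollingMean(season_length=7, window_size=4)
    srm._set_core_tfm(3)                     # lag 3 -> _lag == 2
    assert srm.update_samples == 2 + 7 * 4   # 30

    assert ExponentiallyWeightedMean(alpha=0.5).update_samples == 1
    assert ExpandingQuantile(p=0.5).update_samples == -1  # needs the full history, can't infer keep_last_n

`Offset` and `Combine` derive their values from the wrapped transforms: the wrapped value plus the offset, and the max of the two operands, respectively.
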
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index 1cb2fca6..1fefe533 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -20,7 +20,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"#|hide\n",
"%load_ext autoreload\n",
@@ -916,8 +925,18 @@
" self._dropped_series = None\n",
"\n",
" # once we've computed the features and target we can slice the series\n",
+ " update_samples = [\n",
+ " getattr(tfm, 'update_samples', -1) for tfm in self.transforms.values()\n",
+ " ]\n",
+ " if (\n",
+ " self.keep_last_n is None\n",
+ " and update_samples\n",
+ " and all(samples > 0 for samples in update_samples)\n",
+ " ):\n",
+ " # user didn't set keep_last_n and we can infer it from the transforms\n",
+ " self.keep_last_n = max(update_samples)\n",
" if self.keep_last_n is not None:\n",
- " self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None)) \n",
+ " self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None))\n",
" del self._restore_idxs, self._sort_idxs\n",
"\n",
" # lag transforms\n",
@@ -1671,7 +1690,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
@@ -1692,7 +1711,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
@@ -1978,7 +1997,43 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "---\n",
+ "\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
+ "## TimeSeries.predict\n",
+ "\n",
+ "> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n",
+ "> [sklearn.base.BaseEstimator]]], horizon:int,\n",
+ "> before_predict_callback:Optional[Callable]=None,\n",
+ "> after_predict_callback:Optional[Callable]=None,\n",
+ "> X_df:Optional[~DFType]=None,\n",
+ "> ids:Optional[List[str]]=None)"
+ ],
+ "text/plain": [
+ "---\n",
+ "\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
+ "## TimeSeries.predict\n",
+ "\n",
+ "> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n",
+ "> [sklearn.base.BaseEstimator]]], horizon:int,\n",
+ "> before_predict_callback:Optional[Callable]=None,\n",
+ "> after_predict_callback:Optional[Callable]=None,\n",
+ "> X_df:Optional[~DFType]=None,\n",
+ "> ids:Optional[List[str]]=None)"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"show_doc(TimeSeries.predict, title_level=2)"
]
@@ -2094,7 +2149,41 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "---\n",
+ "\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
+ "## TimeSeries.update\n",
+ "\n",
+ "> TimeSeries.update\n",
+ "> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n",
+ "> .frame.DataFrame])\n",
+ "\n",
+ "*Update the values of the stored series.*"
+ ],
+ "text/plain": [
+ "---\n",
+ "\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "\n",
+ "## TimeSeries.update\n",
+ "\n",
+ "> TimeSeries.update\n",
+ "> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n",
+ "> .frame.DataFrame])\n",
+ "\n",
+ "*Update the values of the stored series.*"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"show_doc(TimeSeries.update, title_level=2)"
]
@@ -2180,7 +2269,15 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "sys:1: CategoricalRemappingWarning: Local categoricals have different encodings, expensive re-encoding is done to perform this merge operation. Consider using a StringCache or an Enum type if the categories are known in advance\n"
+ ]
+ }
+ ],
"source": [
"#| hide\n",
"#| polars\n",
@@ -2429,6 +2526,42 @@
"preds2 = ts2.predict({'model': NaiveModel()}, 10)\n",
"pd.testing.assert_frame_equal(preds, preds2)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# automatically set keep_last_n for built-in lag transforms\n",
+ "ts = TimeSeries(\n",
+ " freq='D',\n",
+ " lags=[1, 2],\n",
+ " date_features=['dayofweek'],\n",
+ " lag_transforms={\n",
+ " 1: [RollingMean(1), RollingMean(4)],\n",
+ " },\n",
+ ")\n",
+ "ts.fit_transform(series, 'unique_id', 'ds', 'y', keep_last_n=20)\n",
+ "assert ts.keep_last_n == 20\n",
+ "ts.fit_transform(series, 'unique_id', 'ds', 'y')\n",
+ "assert ts.keep_last_n == 4\n",
+ "# we can't infer it for functions\n",
+ "ts = TimeSeries(\n",
+ " freq='D',\n",
+ " lags=[1, 2],\n",
+ " date_features=['dayofweek'],\n",
+ " lag_transforms={\n",
+ " 1: [RollingMean(1), RollingMean(4)],\n",
+ " 5: [expanding_mean],\n",
+ " },\n",
+ ")\n",
+ "ts.fit_transform(series, 'unique_id', 'ds', 'y', keep_last_n=20)\n",
+ "assert ts.keep_last_n == 20\n",
+ "ts.fit_transform(series, 'unique_id', 'ds', 'y')\n",
+ "assert ts.keep_last_n is None"
+ ]
}
],
"metadata": {
diff --git a/nbs/forecast.ipynb b/nbs/forecast.ipynb
index d9fc0699..56afd42e 100644
--- a/nbs/forecast.ipynb
+++ b/nbs/forecast.ipynb
@@ -1306,7 +1306,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L127){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast\n",
"\n",
@@ -1337,7 +1337,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L127){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast\n",
"\n",
@@ -1433,7 +1433,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L447){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L446){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit\n",
"\n",
@@ -1467,7 +1467,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L447){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L446){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit\n",
"\n",
@@ -1548,7 +1548,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L966){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L965){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.save\n",
"\n",
@@ -1564,7 +1564,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L966){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L965){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.save\n",
"\n",
@@ -1598,7 +1598,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L983){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L982){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.load\n",
"\n",
@@ -1614,7 +1614,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L983){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L982){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.load\n",
"\n",
@@ -1648,7 +1648,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1006){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1005){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.update\n",
"\n",
@@ -1666,7 +1666,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1006){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L1005){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.update\n",
"\n",
@@ -1702,7 +1702,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L576){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L575){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.make_future_dataframe\n",
"\n",
@@ -1718,7 +1718,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L576){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L575){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.make_future_dataframe\n",
"\n",
@@ -1826,40 +1826,36 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L600){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L599){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.get_missing_future\n",
"\n",
- "> MLForecast.get_missing_future (h:int,\n",
- "> X_df:Union[pandas.core.frame.DataFrame,pol\n",
- "> ars.dataframe.frame.DataFrame])\n",
+ "> MLForecast.get_missing_future (h:int, X_df:~DFType)\n",
"\n",
"*Get the missing id and time combinations in `X_df`.*\n",
"\n",
"| | **Type** | **Details** |\n",
"| -- | -------- | ----------- |\n",
"| h | int | Number of periods to predict. |\n",
- "| X_df | Union | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
- "| **Returns** | **Union** | **DataFrame with expected ids and future times missing in `X_df`** |"
+ "| X_df | DFType | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
+ "| **Returns** | **DFType** | **DataFrame with expected ids and future times missing in `X_df`** |"
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L600){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L599){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.get_missing_future\n",
"\n",
- "> MLForecast.get_missing_future (h:int,\n",
- "> X_df:Union[pandas.core.frame.DataFrame,pol\n",
- "> ars.dataframe.frame.DataFrame])\n",
+ "> MLForecast.get_missing_future (h:int, X_df:~DFType)\n",
"\n",
"*Get the missing id and time combinations in `X_df`.*\n",
"\n",
"| | **Type** | **Details** |\n",
"| -- | -------- | ----------- |\n",
"| h | int | Number of periods to predict. |\n",
- "| X_df | Union | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
- "| **Returns** | **Union** | **DataFrame with expected ids and future times missing in `X_df`** |"
+ "| X_df | DFType | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
+ "| **Returns** | **DFType** | **DataFrame with expected ids and future times missing in `X_df`** |"
]
},
"execution_count": null,
@@ -1896,7 +1892,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L548){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L547){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.forecast_fitted_values\n",
"\n",
@@ -1914,7 +1910,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L548){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L547){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.forecast_fitted_values\n",
"\n",
@@ -2337,18 +2333,16 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L619){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L618){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.predict\n",
"\n",
"> MLForecast.predict (h:int,\n",
"> before_predict_callback:Optional[Callable]=None,\n",
- "> after_predict_callback:Optional[Callable]=None, new_d\n",
- "> f:Union[pandas.core.frame.DataFrame,polars.dataframe.\n",
- "> frame.DataFrame,NoneType]=None,\n",
- "> level:Optional[List[Union[int,float]]]=None, X_df:Uni\n",
- "> on[pandas.core.frame.DataFrame,polars.dataframe.frame\n",
- "> .DataFrame,NoneType]=None,\n",
+ "> after_predict_callback:Optional[Callable]=None,\n",
+ "> new_df:Optional[~DFType]=None,\n",
+ "> level:Optional[List[Union[int,float]]]=None,\n",
+ "> X_df:Optional[~DFType]=None,\n",
"> ids:Optional[List[str]]=None)\n",
"\n",
"*Compute the predictions for the next `h` steps.*\n",
@@ -2358,27 +2352,25 @@
"| h | int | | Number of periods to predict. |\n",
"| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n",
"| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n",
- "| new_df | Union | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n",
+ "| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n",
"| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n",
- "| X_df | Union | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
+ "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
"| ids | Optional | None | List with subset of ids seen during training for which the forecasts should be computed. |\n",
- "| **Returns** | **Union** | | **Predictions for each serie and timestep, with one column per model.** |"
+ "| **Returns** | **DFType** | | **Predictions for each serie and timestep, with one column per model.** |"
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L619){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L618){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.predict\n",
"\n",
"> MLForecast.predict (h:int,\n",
"> before_predict_callback:Optional[Callable]=None,\n",
- "> after_predict_callback:Optional[Callable]=None, new_d\n",
- "> f:Union[pandas.core.frame.DataFrame,polars.dataframe.\n",
- "> frame.DataFrame,NoneType]=None,\n",
- "> level:Optional[List[Union[int,float]]]=None, X_df:Uni\n",
- "> on[pandas.core.frame.DataFrame,polars.dataframe.frame\n",
- "> .DataFrame,NoneType]=None,\n",
+ "> after_predict_callback:Optional[Callable]=None,\n",
+ "> new_df:Optional[~DFType]=None,\n",
+ "> level:Optional[List[Union[int,float]]]=None,\n",
+ "> X_df:Optional[~DFType]=None,\n",
"> ids:Optional[List[str]]=None)\n",
"\n",
"*Compute the predictions for the next `h` steps.*\n",
@@ -2388,11 +2380,11 @@
"| h | int | | Number of periods to predict. |\n",
"| before_predict_callback | Optional | None | Function to call on the features before computing the predictions.
This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
The series identifier is on the index. |\n",
"| after_predict_callback | Optional | None | Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index. |\n",
- "| new_df | Union | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n",
+ "| new_df | Optional | None | Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
If `new_df` is not None, the method will generate forecasts for the new observations. |\n",
"| level | Optional | None | Confidence levels between 0 and 100 for prediction intervals. |\n",
- "| X_df | Union | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
+ "| X_df | Optional | None | Dataframe with the future exogenous features. Should have the id column and the time column. |\n",
"| ids | Optional | None | List with subset of ids seen during training for which the forecasts should be computed. |\n",
- "| **Returns** | **Union** | | **Predictions for each serie and timestep, with one column per model.** |"
+ "| **Returns** | **DFType** | | **Predictions for each serie and timestep, with one column per model.** |"
]
},
"execution_count": null,
@@ -2947,13 +2939,11 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L206){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L205){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.preprocess\n",
"\n",
- "> MLForecast.preprocess\n",
- "> (df:Union[pandas.core.frame.DataFrame,polars.dataf\n",
- "> rame.frame.DataFrame], id_col:str='unique_id',\n",
+ "> MLForecast.preprocess (df:~DFType, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y',\n",
"> static_features:Optional[List[str]]=None,\n",
"> dropna:bool=True, keep_last_n:Optional[int]=None,\n",
@@ -2964,7 +2954,7 @@
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
- "| df | Union | | Series data in long format. |\n",
+ "| df | DFType | | Series data in long format. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
"| time_col | str | ds | Column that identifies each timestep, its values can be timestamps or integers. |\n",
"| target_col | str | y | Column that contains the target. |\n",
@@ -2979,13 +2969,11 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L206){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L205){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.preprocess\n",
"\n",
- "> MLForecast.preprocess\n",
- "> (df:Union[pandas.core.frame.DataFrame,polars.dataf\n",
- "> rame.frame.DataFrame], id_col:str='unique_id',\n",
+ "> MLForecast.preprocess (df:~DFType, id_col:str='unique_id',\n",
"> time_col:str='ds', target_col:str='y',\n",
"> static_features:Optional[List[str]]=None,\n",
"> dropna:bool=True, keep_last_n:Optional[int]=None,\n",
@@ -2996,7 +2984,7 @@
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
- "| df | Union | | Series data in long format. |\n",
+ "| df | DFType | | Series data in long format. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
"| time_col | str | ds | Column that identifies each timestep, its values can be timestamps or integers. |\n",
"| target_col | str | y | Column that contains the target. |\n",
@@ -3299,7 +3287,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L262){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L261){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit_models\n",
"\n",
@@ -3318,7 +3306,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L262){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L261){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.fit_models\n",
"\n",
@@ -3443,15 +3431,13 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L767){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L766){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.cross_validation\n",
"\n",
- "> MLForecast.cross_validation\n",
- "> (df:Union[pandas.core.frame.DataFrame,polars\n",
- "> .dataframe.frame.DataFrame], n_windows:int,\n",
- "> h:int, id_col:str='unique_id',\n",
- "> time_col:str='ds', target_col:str='y',\n",
+ "> MLForecast.cross_validation (df:~DFType, n_windows:int, h:int,\n",
+ "> id_col:str='unique_id', time_col:str='ds',\n",
+ "> target_col:str='y',\n",
"> step_size:Optional[int]=None,\n",
"> static_features:Optional[List[str]]=None,\n",
"> dropna:bool=True,\n",
@@ -3472,7 +3458,7 @@
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
- "| df | Union | | Series data in long format. |\n",
+ "| df | DFType | | Series data in long format. |\n",
"| n_windows | int | | Number of windows to evaluate. |\n",
"| h | int | | Forecast horizon. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
@@ -3491,20 +3477,18 @@
"| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n",
"| fitted | bool | False | Store the in-sample predictions. |\n",
"| as_numpy | bool | False | Cast features to numpy array. |\n",
- "| **Returns** | **Union** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |"
+ "| **Returns** | **DFType** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |"
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L767){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L766){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.cross_validation\n",
"\n",
- "> MLForecast.cross_validation\n",
- "> (df:Union[pandas.core.frame.DataFrame,polars\n",
- "> .dataframe.frame.DataFrame], n_windows:int,\n",
- "> h:int, id_col:str='unique_id',\n",
- "> time_col:str='ds', target_col:str='y',\n",
+ "> MLForecast.cross_validation (df:~DFType, n_windows:int, h:int,\n",
+ "> id_col:str='unique_id', time_col:str='ds',\n",
+ "> target_col:str='y',\n",
"> step_size:Optional[int]=None,\n",
"> static_features:Optional[List[str]]=None,\n",
"> dropna:bool=True,\n",
@@ -3525,7 +3509,7 @@
"\n",
"| | **Type** | **Default** | **Details** |\n",
"| -- | -------- | ----------- | ----------- |\n",
- "| df | Union | | Series data in long format. |\n",
+ "| df | DFType | | Series data in long format. |\n",
"| n_windows | int | | Number of windows to evaluate. |\n",
"| h | int | | Forecast horizon. |\n",
"| id_col | str | unique_id | Column that identifies each serie. |\n",
@@ -3544,7 +3528,7 @@
"| input_size | Optional | None | Maximum training samples per serie in each window. If None, will use an expanding window. |\n",
"| fitted | bool | False | Store the in-sample predictions. |\n",
"| as_numpy | bool | False | Cast features to numpy array. |\n",
- "| **Returns** | **Union** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |"
+ "| **Returns** | **DFType** | | **Predictions for each window with the series id, timestamp, last train date, target value and predictions from each model.** |"
]
},
"execution_count": null,
@@ -4286,6 +4270,7 @@
" h=horizon,\n",
" step_size=horizon,\n",
" input_size=input_size,\n",
+ " keep_last_n=input_size,\n",
")\n",
"series_lengths = np.diff(fcst.ts.ga.indptr)\n",
"unique_lengths = np.unique(series_lengths)\n",
@@ -4437,7 +4422,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L192){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L191){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.from_cv\n",
"\n",
@@ -4446,7 +4431,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L192){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/forecast.py#L191){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### MLForecast.from_cv\n",
"\n",
@@ -4480,7 +4465,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "[LightGBM] [Info] Start training from score 0.084340\n",
"[10] mape: 0.118569\n",
"[20] mape: 0.111506\n",
"[30] mape: 0.107314\n",
diff --git a/nbs/lag_transforms.ipynb b/nbs/lag_transforms.ipynb
index 9caac5ee..2eb96f58 100644
--- a/nbs/lag_transforms.ipynb
+++ b/nbs/lag_transforms.ipynb
@@ -113,7 +113,15 @@
" out._core_tfm = transforms[0]._core_tfm.stack(\n",
" [tfm._core_tfm for tfm in transforms]\n",
" )\n",
- " return out"
+ " return out\n",
+ "\n",
+ " @property\n",
+ " def _lag(self):\n",
+ " return self._core_tfm.lag - 1\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return -1"
]
},
{
@@ -125,6 +133,7 @@
"source": [
"#| exporti\n",
"class Lag(_BaseLagTransform):\n",
+ " \n",
" def __init__(self, lag: int):\n",
" self.lag = lag\n",
" self._core_tfm = core_tfms.Lag(lag=lag)\n",
@@ -136,7 +145,11 @@
" return f'lag{lag}'\n",
"\n",
" def __eq__(self, other):\n",
- " return isinstance(other, Lag) and self.lag == other.lag"
+ " return isinstance(other, Lag) and self.lag == other.lag\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return self.lag"
]
},
{
@@ -160,7 +173,11 @@
" If `None`, will be set to `window_size`.\n",
" \"\"\"\n",
" self.window_size = window_size\n",
- " self.min_samples = min_samples"
+ " self.min_samples = min_samples\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return self._lag + self.window_size"
]
},
{
@@ -248,7 +265,11 @@
" \"\"\" \n",
" self.season_length = season_length\n",
" self.window_size = window_size\n",
- " self.min_samples = min_samples"
+ " self.min_samples = min_samples\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return self._lag + self.season_length * self.window_size"
]
},
{
@@ -311,7 +332,11 @@
"class _ExpandingBase(_BaseLagTransform):\n",
" \"\"\"Expanding statistic\"\"\"\n",
" def __init__(self):\n",
- " ..."
+ " ...\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return 1"
]
},
{
@@ -336,7 +361,11 @@
"\n",
"class ExpandingQuantile(_ExpandingBase):\n",
" def __init__(self, p: float):\n",
- " self.p = p"
+ " self.p = p\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return -1"
]
},
{
@@ -378,7 +407,11 @@
" alpha : float\n",
" Smoothing factor.\"\"\"\n",
" def __init__(self, alpha: float):\n",
- " self.alpha = alpha"
+ " self.alpha = alpha\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return 1"
]
},
{
@@ -429,8 +462,13 @@
" return self.tfm._get_name(lag + self.n)\n",
"\n",
" def _set_core_tfm(self, lag: int) -> 'Offset':\n",
- " self._core_tfm = clone(self.tfm)._set_core_tfm(lag + self.n)\n",
- " return self"
+ " self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n)\n",
+ " self._core_tfm = self.tfm._core_tfm\n",
+ " return self\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self) -> int:\n",
+ " return self.tfm.update_samples + self.n"
]
},
{
@@ -488,7 +526,11 @@
" return self.operator(self.tfm1.transform(ga), self.tfm2.transform(ga))\n",
"\n",
" def update(self, ga: CoreGroupedArray) -> np.ndarray:\n",
- " return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))"
+ " return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))\n",
+ "\n",
+ " @property\n",
+ " def update_samples(self):\n",
+ " return max(self.tfm1.update_samples, self.tfm2.update_samples)"
]
},
{
@@ -499,7 +541,9 @@
"outputs": [],
"source": [
"#| hide\n",
- "import operator"
+ "import operator\n",
+ "\n",
+ "from mlforecast.grouped_array import GroupedArray as MLGroupedArray"
]
},
{
@@ -551,7 +595,14 @@
" tfm._set_core_tfm(1)\n",
" tfm._get_name(1)\n",
" tfm.transform(ga)\n",
- " tfm.update(ga)"
+ " updates = tfm.update(ga)\n",
+ " upd_samples = tfm.update_samples\n",
+ " if upd_samples > -1:\n",
+ " sliced_ga = MLGroupedArray(ga.data, ga.indptr).take_from_groups(slice(-upd_samples, None))\n",
+ " ga2 = CoreGroupedArray(sliced_ga.data, sliced_ga.indptr)\n",
+ " tfm.transform(ga) # to reset state\n",
+ " updates2 = tfm.update(ga2)\n",
+ " np.testing.assert_allclose(updates, updates2)"
]
}
],