Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: infer samples required for built-in lag transforms updates #445

Merged
merged 5 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Combine.update': ( 'lag_transforms.html#combine.update',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Combine.update_samples': ( 'lag_transforms.html#combine.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMax': ( 'lag_transforms.html#expandingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingMean': ( 'lag_transforms.html#expandingmean',
Expand All @@ -272,12 +274,16 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingQuantile.__init__': ( 'lag_transforms.html#expandingquantile.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingQuantile.update_samples': ( 'lag_transforms.html#expandingquantile.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExpandingStd': ( 'lag_transforms.html#expandingstd',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean': ( 'lag_transforms.html#exponentiallyweightedmean',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean.__init__': ( 'lag_transforms.html#exponentiallyweightedmean.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.ExponentiallyWeightedMean.update_samples': ( 'lag_transforms.html#exponentiallyweightedmean.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag': ('lag_transforms.html#lag', 'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag.__eq__': ( 'lag_transforms.html#lag.__eq__',
'mlforecast/lag_transforms.py'),
Expand All @@ -287,6 +293,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag._set_core_tfm': ( 'lag_transforms.html#lag._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Lag.update_samples': ( 'lag_transforms.html#lag.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset': ( 'lag_transforms.html#offset',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset.__init__': ( 'lag_transforms.html#offset.__init__',
Expand All @@ -295,6 +303,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset._set_core_tfm': ( 'lag_transforms.html#offset._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.Offset.update_samples': ( 'lag_transforms.html#offset.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMax': ( 'lag_transforms.html#rollingmax',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms.RollingMean': ( 'lag_transforms.html#rollingmean',
Expand Down Expand Up @@ -327,6 +337,8 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform._get_name': ( 'lag_transforms.html#_baselagtransform._get_name',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform._lag': ( 'lag_transforms.html#_baselagtransform._lag',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform._set_core_tfm': ( 'lag_transforms.html#_baselagtransform._set_core_tfm',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform.stack': ( 'lag_transforms.html#_baselagtransform.stack',
Expand All @@ -337,18 +349,26 @@
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform.update': ( 'lag_transforms.html#_baselagtransform.update',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._BaseLagTransform.update_samples': ( 'lag_transforms.html#_baselagtransform.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._ExpandingBase': ( 'lag_transforms.html#_expandingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._ExpandingBase.__init__': ( 'lag_transforms.html#_expandingbase.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._ExpandingBase.update_samples': ( 'lag_transforms.html#_expandingbase.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._RollingBase': ( 'lag_transforms.html#_rollingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._RollingBase.__init__': ( 'lag_transforms.html#_rollingbase.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._RollingBase.update_samples': ( 'lag_transforms.html#_rollingbase.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._Seasonal_RollingBase': ( 'lag_transforms.html#_seasonal_rollingbase',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._Seasonal_RollingBase.__init__': ( 'lag_transforms.html#_seasonal_rollingbase.__init__',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._Seasonal_RollingBase.update_samples': ( 'lag_transforms.html#_seasonal_rollingbase.update_samples',
'mlforecast/lag_transforms.py'),
'mlforecast.lag_transforms._pascal2camel': ( 'lag_transforms.html#_pascal2camel',
'mlforecast/lag_transforms.py')},
'mlforecast.lgb_cv': { 'mlforecast.lgb_cv.LightGBMCV': ('lgb_cv.html#lightgbmcv', 'mlforecast/lgb_cv.py'),
Expand Down
10 changes: 10 additions & 0 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,16 @@ def _transform(
self._dropped_series = None

# once we've computed the features and target we can slice the series
update_samples = [
getattr(tfm, "update_samples", -1) for tfm in self.transforms.values()
]
if (
self.keep_last_n is None
and update_samples
and all(samples > 0 for samples in update_samples)
):
# user didn't set keep_last_n and we can infer it from the transforms
self.keep_last_n = max(update_samples)
if self.keep_last_n is not None:
self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None))
del self._restore_idxs, self._sort_idxs
Expand Down
44 changes: 43 additions & 1 deletion mlforecast/lag_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,17 @@ def stack(transforms: Sequence["_BaseLagTransform"]) -> "_BaseLagTransform":
)
return out

@property
def _lag(self):
return self._core_tfm.lag - 1

@property
def update_samples(self) -> int:
    """Number of trailing samples needed to update this transform.

    The base implementation reports ``-1``. Callers that require a
    strictly positive value treat this as "cannot be inferred".
    """
    not_inferable = -1
    return not_inferable

# %% ../nbs/lag_transforms.ipynb 6
class Lag(_BaseLagTransform):

def __init__(self, lag: int):
self.lag = lag
self._core_tfm = core_tfms.Lag(lag=lag)
Expand All @@ -83,6 +92,10 @@ def _get_name(self, lag: int) -> str:
def __eq__(self, other):
    """Two ``Lag`` instances are equal when their ``lag`` values match.

    Anything that is not a ``Lag`` compares unequal.
    """
    if not isinstance(other, Lag):
        return False
    return self.lag == other.lag

@property
def update_samples(self) -> int:
    """Samples needed to update a plain lag: the configured ``lag`` itself."""
    required = self.lag
    return required

# %% ../nbs/lag_transforms.ipynb 7
class _RollingBase(_BaseLagTransform):
"Rolling statistic"
Expand All @@ -100,6 +113,10 @@ def __init__(self, window_size: int, min_samples: Optional[int] = None):
self.window_size = window_size
self.min_samples = min_samples

@property
def update_samples(self) -> int:
    """Samples needed to update a rolling statistic.

    This is the sum of ``self._lag`` and ``self.window_size``.
    """
    lead = self._lag
    return lead + self.window_size

# %% ../nbs/lag_transforms.ipynb 8
class RollingMean(_RollingBase):
    # Adds no members of its own; everything comes from _RollingBase.
    pass

Expand Down Expand Up @@ -149,6 +166,10 @@ def __init__(
self.window_size = window_size
self.min_samples = min_samples

@property
def update_samples(self) -> int:
    """Samples needed to update a seasonal rolling statistic.

    This is ``self._lag`` plus ``season_length * window_size``.
    """
    lead = self._lag
    seasonal_span = self.season_length * self.window_size
    return lead + seasonal_span

# %% ../nbs/lag_transforms.ipynb 11
class SeasonalRollingMean(_Seasonal_RollingBase):
    # Adds no members of its own; everything comes from _Seasonal_RollingBase.
    pass

Expand Down Expand Up @@ -183,6 +204,10 @@ class _ExpandingBase(_BaseLagTransform):

def __init__(self): ...

@property
def update_samples(self) -> int:
    """Samples needed to update an expanding statistic: always ``1``."""
    # NOTE(review): presumably the core transform keeps running state, so
    # only the newest observation is needed -- confirm against core_tfms.
    single_sample = 1
    return single_sample

# %% ../nbs/lag_transforms.ipynb 14
class ExpandingMean(_ExpandingBase):
    # Adds no members of its own; everything comes from _ExpandingBase.
    pass

Expand All @@ -200,6 +225,10 @@ class ExpandingQuantile(_ExpandingBase):
def __init__(self, p: float):
self.p = p

@property
def update_samples(self) -> int:
    """Report ``-1`` instead of the ``1`` inherited from ``_ExpandingBase``.

    Callers that require a strictly positive value treat ``-1`` as
    "cannot be inferred".
    """
    not_inferable = -1
    return not_inferable

# %% ../nbs/lag_transforms.ipynb 16
class ExponentiallyWeightedMean(_BaseLagTransform):
"""Exponentially weighted average
Expand All @@ -212,6 +241,10 @@ class ExponentiallyWeightedMean(_BaseLagTransform):
def __init__(self, alpha: float):
self.alpha = alpha

@property
def update_samples(self) -> int:
    """Samples needed to update the exponentially weighted mean: always ``1``."""
    # NOTE(review): presumably the previous smoothed value is kept as state,
    # so only the newest observation is needed -- confirm against core_tfms.
    single_sample = 1
    return single_sample

# %% ../nbs/lag_transforms.ipynb 18
class Offset(_BaseLagTransform):
"""Shift series before computing transformation
Expand All @@ -231,9 +264,14 @@ def _get_name(self, lag: int) -> str:
return self.tfm._get_name(lag + self.n)

def _set_core_tfm(self, lag: int) -> "Offset":
    """Configure the wrapped transform for ``lag`` shifted by ``self.n``.

    Parameters
    ----------
    lag : int
        Lag requested by the caller; the wrapped transform is configured
        with ``lag + self.n``.

    Returns
    -------
    Offset
        ``self``, so the call can be chained.
    """
    # Replace the wrapped transform with a configured clone and expose its
    # core transform as our own. The clone is built once: the earlier extra
    # `clone(...)._set_core_tfm(...)` stored in `_core_tfm` was overwritten
    # straight away, so it was a dead store.
    self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n)
    self._core_tfm = self.tfm._core_tfm
    return self

@property
def update_samples(self) -> int:
    """Samples needed to update the wrapped transform after the offset.

    Returns
    -------
    int
        ``self.tfm.update_samples + self.n`` when the wrapped transform
        reports a non-negative requirement, otherwise ``-1``.
    """
    inner = self.tfm.update_samples
    # A negative value is the "cannot be inferred" sentinel. Adding the
    # offset to it would give a positive number for n >= 2, which callers
    # would take for a real requirement, so pass the sentinel through.
    if inner < 0:
        return -1
    # NOTE(review): once `_set_core_tfm` has run, the wrapped transform's
    # core lag already includes `n`, so for lag-aware transforms this sum
    # looks conservative -- confirm that is intended.
    return inner + self.n

# %% ../nbs/lag_transforms.ipynb 20
class Combine(_BaseLagTransform):
"""Combine two lag transformations using an operator
Expand Down Expand Up @@ -269,3 +307,7 @@ def transform(self, ga: CoreGroupedArray) -> np.ndarray:

def update(self, ga: CoreGroupedArray) -> np.ndarray:
return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))

@property
def update_samples(self) -> int:
    """Samples needed to update both combined transforms.

    Returns
    -------
    int
        The larger of the two operands' requirements, or ``-1`` when
        either operand reports a negative ("cannot be inferred") value.
    """
    first = self.tfm1.update_samples
    second = self.tfm2.update_samples
    # `max` alone would hide a -1 sentinel behind the other operand's
    # positive value and under-report what the combination needs.
    if first < 0 or second < 0:
        return -1
    return max(first, second)
Loading
Loading