Skip to content

Commit

Permalink
unit8co#1101 Implemented min_train_series_length for Theta and FourTheta
Browse files Browse the repository at this point in the history
  • Loading branch information
Rijk van der Meulen committed Aug 1, 2022
1 parent caccce1 commit 6f989af
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Darts is still in an early development phase and we cannot always guarantee back
- An issue with arguments being reverted for the `metric` function of gridsearch and backtest [#989](https://github.com/unit8co/darts/pull/989) by [Clara Grotehans](https://github.com/ClaraGrthns).
- An error checking whether `fit()` has been called in global models [#944](https://github.com/unit8co/darts/pull/944) by [Julien Herzen](https://github.com/hrzn).
- An error in Gaussian Process filter happening with newer versions of sklearn [#963](https://github.com/unit8co/darts/pull/963) by [Julien Herzen](https://github.com/hrzn).
- Implemented the min_train_series_length method for the FourTheta and Theta models that overwrites the minimum default of 3 training samples by 2*seasonal_period when appropriate [#1101](https://github.com/unit8co/darts/pull/1101) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)

### For developers of the library:

Expand Down
22 changes: 22 additions & 0 deletions darts/models/forecasting/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def predict(self, n: int, num_samples: int = 1) -> "TimeSeries":
def __str__(self):
return f"Theta({self.theta})"

@property
def min_train_series_length(self) -> int:
if (
self.season_mode != SeasonalityMode.NONE
and self.seasonality_period
and self.seasonality_period > 1
):
return 2 * self.seasonality_period
else:
return 3


class FourTheta(ForecastingModel):
def __init__(
Expand Down Expand Up @@ -457,3 +468,14 @@ def __str__(self):
return "4Theta(theta:{}, curve:{}, model:{}, seasonality:{})".format(
self.theta, self.trend_mode, self.model_mode, self.season_mode
)

@property
def min_train_series_length(self) -> int:
if (
self.season_mode != SeasonalityMode.NONE
and self.seasonality_period
and self.seasonality_period > 1
):
return 2 * self.seasonality_period
else:
return 3
48 changes: 48 additions & 0 deletions darts/tests/models/forecasting/test_4theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,51 @@ def test_best_model(self):
self.assertTrue(
mape(val_series, forecast_best) <= mape(val_series, forecast_random)
)

def test_min_train_series_length_with_seasonality(self):
seasonality_period = 12
fourtheta = FourTheta(
model_mode=ModelMode.MULTIPLICATIVE,
trend_mode=TrendMode.EXPONENTIAL,
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=seasonality_period,
normalization=False,
)
theta = Theta(
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=seasonality_period,
)
self.assertEqual(fourtheta.min_train_series_length, 2 * seasonality_period)
self.assertEqual(theta.min_train_series_length, 2 * seasonality_period)

def test_min_train_series_length_without_seasonality(self):
fourtheta = FourTheta(
model_mode=ModelMode.MULTIPLICATIVE,
trend_mode=TrendMode.EXPONENTIAL,
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=None,
normalization=False,
)
theta = Theta(
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=None,
)
self.assertEqual(fourtheta.min_train_series_length, 3)
self.assertEqual(theta.min_train_series_length, 3)

def test_fit_insufficient_train_series_length(self):
sine_series = st(length=21, freq="MS")
with self.assertRaises(ValueError):
fourtheta = FourTheta(
model_mode=ModelMode.MULTIPLICATIVE,
trend_mode=TrendMode.EXPONENTIAL,
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=12,
)
fourtheta.fit(sine_series)
with self.assertRaises(ValueError):
theta = Theta(
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=12,
)
theta.fit(sine_series)

0 comments on commit 6f989af

Please sign in to comment.