Skip to content

Commit

Permalink
Merge branch 'master' into fix/nbeats-nhits-TODOs
Browse files Browse the repository at this point in the history
  • Loading branch information
hrzn authored May 18, 2022
2 parents 3cc47f1 + adb66fd commit e1ed821
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 13 deletions.
7 changes: 3 additions & 4 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def predict(self, n: int, num_samples: int = 1) -> TimeSeries:
if not self._fit_called:
raise_log(
ValueError(
"The model must be fit before calling `predict()`."
"For global models, if `predict()` is called without specifying a series,"
"The model must be fit before calling predict(). "
"For global models, if predict() is called without specifying a series, "
"the model must have been fit on a single training series."
),
logger,
Expand Down Expand Up @@ -1014,8 +1014,7 @@ def predict(
If `series` is given and is a sequence of several time series, this function returns
a sequence where each element contains the corresponding `n` points forecasts.
"""
if series is None and past_covariates is None and future_covariates is None:
super().predict(n, num_samples)
super().predict(n, num_samples)
if self._expect_past_covariates and past_covariates is None:
raise_log(
ValueError(
Expand Down
11 changes: 9 additions & 2 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
raise_log,
suppress_lightning_warnings,
)
from darts.models.forecasting.forecasting_model import GlobalForecastingModel
from darts.models.forecasting.forecasting_model import (
ForecastingModel,
GlobalForecastingModel,
)
from darts.models.forecasting.pl_forecasting_module import PLForecastingModule
from darts.timeseries import TimeSeries
from darts.utils.data.encoders import SequentialEncoder
Expand Down Expand Up @@ -830,7 +833,7 @@ def fit_from_dataset(
self
Fitted model.
"""

self._fit_called = True
self._verify_train_dataset_type(train_dataset)
raise_if(
len(train_dataset) == 0,
Expand Down Expand Up @@ -1173,6 +1176,10 @@ def predict_from_dataset(
Sequence[TimeSeries]
Returns one or more forecasts for time series.
"""

# We need to call super's super's method directly, because GlobalForecastingModel expects series:
ForecastingModel.predict(self, n, num_samples)

self._verify_inference_dataset_type(input_series_dataset)

# check that covariates and dimensions are matching what we had during training
Expand Down
10 changes: 7 additions & 3 deletions darts/tests/models/forecasting/test_TFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,14 @@ def helper_fit_predict(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=100,
num_samples=(100 if model._is_probabilistic() else 1),
)

if isinstance(y_hat, TimeSeries):
y_hat = y_hat.quantile_timeseries(0.5)
y_hat = y_hat.quantile_timeseries(0.5) if y_hat.n_samples > 1 else y_hat
else:
y_hat = [ts.quantile_timeseries(0.5) for ts in y_hat]
y_hat = [
ts.quantile_timeseries(0.5) if ts.n_samples > 1 else ts
for ts in y_hat
]
return y_hat
10 changes: 10 additions & 0 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,11 @@ def test_head_overshot_sample_axis(self):
result = self.ts.head(20, axis="sample")
self.assertEqual(10, result.n_samples)

def test_head_numeric_time_index(self):
s = TimeSeries.from_values(self.ts.values())
# taking the head should not crash
s.head()

def test_tail_overshot_time_axis(self):
result = self.ts.tail(20)
self.assertEqual(10, result.n_timesteps)
Expand All @@ -1371,6 +1376,11 @@ def test_tail_overshot_sample_axis(self):
result = self.ts.tail(20, axis="sample")
self.assertEqual(10, result.n_samples)

def test_tail_numeric_time_index(self):
s = TimeSeries.from_values(self.ts.values())
# taking the tail should not crash
s.tail()


class TimeSeriesFromDataFrameTestCase(DartsBaseTestClass):
def test_from_dataframe_sunny_day(self):
Expand Down
16 changes: 12 additions & 4 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,8 +1261,12 @@ def head(
"""

axis_str = self._get_dim_name(axis)
display_n = range(min(size, self._xa.sizes[axis_str]))
return self.__class__(self._xa[{axis_str: display_n}])
display_n = min(size, self._xa.sizes[axis_str])

if axis_str == self._time_dim:
return self[:display_n]
else:
return self.__class__(self._xa[{axis_str: range(display_n)}])

def tail(
self, size: Optional[int] = 5, axis: Optional[Union[int, str]] = 0
Expand All @@ -1284,8 +1288,12 @@ def tail(
"""

axis_str = self._get_dim_name(axis)
display_n = range(-min(size, self._xa.sizes[axis_str]), 0)
return self.__class__(self._xa[{axis_str: display_n}])
display_n = min(size, self._xa.sizes[axis_str])

if axis_str == self._time_dim:
return self[-display_n:]
else:
return self.__class__(self._xa[{axis_str: range(-display_n, 0)}])

def concatenate(
self,
Expand Down

0 comments on commit e1ed821

Please sign in to comment.