diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index cc863d2d03..821700e3cd 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -191,8 +191,8 @@ def predict(self, n: int, num_samples: int = 1) -> TimeSeries: if not self._fit_called: raise_log( ValueError( - "The model must be fit before calling `predict()`." - "For global models, if `predict()` is called without specifying a series," + "The model must be fit before calling predict(). " + "For global models, if predict() is called without specifying a series, " "the model must have been fit on a single training series." ), logger, @@ -1014,8 +1014,7 @@ def predict( If `series` is given and is a sequence of several time series, this function returns a sequence where each element contains the corresponding `n` points forecasts. """ - if series is None and past_covariates is None and future_covariates is None: - super().predict(n, num_samples) + super().predict(n, num_samples) if self._expect_past_covariates and past_covariates is None: raise_log( ValueError( diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index d1423f1b1c..8e3a23ab79 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -40,7 +40,10 @@ raise_log, suppress_lightning_warnings, ) -from darts.models.forecasting.forecasting_model import GlobalForecastingModel +from darts.models.forecasting.forecasting_model import ( + ForecastingModel, + GlobalForecastingModel, +) from darts.models.forecasting.pl_forecasting_module import PLForecastingModule from darts.timeseries import TimeSeries from darts.utils.data.encoders import SequentialEncoder @@ -830,7 +833,7 @@ def fit_from_dataset( self Fitted model. """ - + self._fit_called = True self._verify_train_dataset_type(train_dataset) raise_if( len(train_dataset) == 0, @@ -1173,6 +1176,10 @@ def predict_from_dataset( Sequence[TimeSeries] Returns one or more forecasts for time series. """ + + # We need to call super's super's method directly, because GlobalForecastingModel expects series: + ForecastingModel.predict(self, n, num_samples) + self._verify_inference_dataset_type(input_series_dataset) # check that covariates and dimensions are matching what we had during training diff --git a/darts/tests/models/forecasting/test_TFT.py b/darts/tests/models/forecasting/test_TFT.py index ed42f5620d..3cda144c16 100644 --- a/darts/tests/models/forecasting/test_TFT.py +++ b/darts/tests/models/forecasting/test_TFT.py @@ -282,10 +282,14 @@ def helper_fit_predict( series=series, past_covariates=past_covariates, future_covariates=future_covariates, - num_samples=100, + num_samples=(100 if model._is_probabilistic() else 1), ) + if isinstance(y_hat, TimeSeries): - y_hat = y_hat.quantile_timeseries(0.5) + y_hat = y_hat.quantile_timeseries(0.5) if y_hat.n_samples > 1 else y_hat else: - y_hat = [ts.quantile_timeseries(0.5) for ts in y_hat] + y_hat = [ + ts.quantile_timeseries(0.5) if ts.n_samples > 1 else ts + for ts in y_hat + ] return y_hat diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index a173eb86d3..d8520a334b 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -1358,6 +1358,11 @@ def test_head_overshot_sample_axis(self): result = self.ts.head(20, axis="sample") self.assertEqual(10, result.n_samples) + def test_head_numeric_time_index(self): + s = TimeSeries.from_values(self.ts.values()) + # taking the head should not crash + s.head() + def test_tail_overshot_time_axis(self): result = self.ts.tail(20) self.assertEqual(10, result.n_timesteps) @@ -1371,6 +1376,11 @@ def test_tail_overshot_sample_axis(self): result = self.ts.tail(20, axis="sample") self.assertEqual(10, result.n_samples) + def test_tail_numeric_time_index(self): + s = TimeSeries.from_values(self.ts.values()) + # taking the tail should not crash + s.tail() + class TimeSeriesFromDataFrameTestCase(DartsBaseTestClass): def test_from_dataframe_sunny_day(self): diff --git a/darts/timeseries.py b/darts/timeseries.py index 8f31a8d80f..26d581afe4 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -1261,8 +1261,12 @@ def head( """ axis_str = self._get_dim_name(axis) - display_n = range(min(size, self._xa.sizes[axis_str])) - return self.__class__(self._xa[{axis_str: display_n}]) + display_n = min(size, self._xa.sizes[axis_str]) + + if axis_str == self._time_dim: + return self[:display_n] + else: + return self.__class__(self._xa[{axis_str: range(display_n)}]) def tail( self, size: Optional[int] = 5, axis: Optional[Union[int, str]] = 0 @@ -1284,8 +1288,12 @@ def tail( """ axis_str = self._get_dim_name(axis) - display_n = range(-min(size, self._xa.sizes[axis_str]), 0) - return self.__class__(self._xa[{axis_str: display_n}]) + display_n = min(size, self._xa.sizes[axis_str]) + + if axis_str == self._time_dim: + return self[-display_n:] + else: + return self.__class__(self._xa[{axis_str: range(-display_n, 0)}]) def concatenate( self,