diff --git a/darts/models/filtering/gaussian_process_filter.py b/darts/models/filtering/gaussian_process_filter.py index 77e3e099b6..047ff8953a 100644 --- a/darts/models/filtering/gaussian_process_filter.py +++ b/darts/models/filtering/gaussian_process_filter.py @@ -71,4 +71,5 @@ def filter(self, series: TimeSeries, num_samples: int = 1) -> TimeSeries: else: filtered_values = self.model.sample_y(times, n_samples=num_samples) + filtered_values = filtered_values.reshape(len(times), -1, num_samples) return TimeSeries.from_times_and_values(series.time_index, filtered_values)