Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/RegressionEnsemble with single model regressor and coef access in LinearRegressionModel #2205

Merged
merged 7 commits into from
Feb 5, 2024
Merged
4 changes: 4 additions & 0 deletions darts/models/forecasting/linear_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def fit(

for quantile in self.quantiles:
self.kwargs["quantile"] = quantile
# assign the Quantile regressor to self.model to leverage existing logic
self.model = QuantileRegressor(**self.kwargs)
super().fit(
series=series,
Expand All @@ -256,6 +257,9 @@ def fit(

self._model_container[quantile] = self.model

# replace the last trained QuantileRegressor with the dictionnary of Regressors.
self.model = self._model_container

return self

else:
Expand Down
5 changes: 5 additions & 0 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def __init__(
lags=None, lags_future_covariates=[0], fit_intercept=False
)
elif isinstance(regression_model, RegressionModel):
raise_if_not(
regression_model.multi_models,
"`regression_model.multi_models = False` is not supported for `RegressionEnsembleModel.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)
regression_model = regression_model
else:
# scikit-learn like model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,31 @@ def test_predict_likelihood_parameters_multivariate_regression_ensemble(self):
pred_ens["linear_q0.05"].values() < pred_ens["linear_q0.50"].values()
) and all(pred_ens["linear_q0.50"].values() < pred_ens["linear_q0.95"].values())

def test_wrong_model_creation_params(self):
"""Since `multi_models=False` requires to shift the regression model lags in the past (outside of the forecasting
model predictions), it is not supported."""
forcasting_models = [
self.get_deterministic_global_model(2),
self.get_deterministic_global_model([-5, -7]),
]
RegressionEnsembleModel(
forecasting_models=forcasting_models,
regression_train_n_points=10,
regression_model=LinearRegressionModel(
lags_future_covariates=[0], output_chunk_length=2, multi_models=True
),
)
with pytest.raises(ValueError):
RegressionEnsembleModel(
forecasting_models=forcasting_models,
regression_train_n_points=10,
regression_model=LinearRegressionModel(
lags_future_covariates=[0],
output_chunk_length=2,
multi_models=False,
),
)

@staticmethod
def get_probabilistic_global_model(
lags: Union[int, List[int]],
Expand Down
Loading