Merge branch 'dev' into cop_eval_all_levels
rshyamsundar authored Aug 2, 2023
2 parents e58cb15 + 851f3be commit c5ef110
Showing 5 changed files with 37 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/mx/prelude.py

@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-# flake8: noqa: F401
+# flake8: noqa: F401, F403
 
 from .component import *
 from .serde import *

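For reference, flake8's F401 check flags imports that are unused in the file and F403 flags "from module import *"; a prelude module exists purely to re-export names through star imports, so both checks are silenced. A minimal illustration of the pattern, with hypothetical module names:

    # mypkg/prelude.py
    # flake8: noqa: F401, F403
    #
    # F401 would otherwise report these imports as unused in this file, and
    # F403 would report the star imports themselves; both are intentional,
    # since this module only re-exports names for "from mypkg.prelude import *".
    from .core import *   # hypothetical submodule
    from .utils import *  # hypothetical submodule
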
2 changes: 1 addition & 1 deletion src/gluonts/nursery/daf/requirements.txt

@@ -3,5 +3,5 @@ tensorboard==2.3.0
 numpy==1.22.0
 pandas==1.1.5
 scikit-learn==0.23.2
-scipy==1.5.2
+scipy==1.10.0
 matplotlib==3.3.2

@@ -13,7 +13,7 @@
 
 from copy import deepcopy
 from functools import partial
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 from mxnet.gluon import HybridBlock
 import numpy as np
@@ -41,8 +41,9 @@
     SelectFields,
     SimpleTransformation,
     Transformation,
+    MissingValueImputation,
+    RollingMeanValueImputation,
 )
 
 from gluonts.nursery.temporal_hierarchical_forecasting.model.cop_deepar.gluonts_fixes import (
     batchify_with_dict,
     DeepAREstimatorForCOP,
@@ -171,6 +172,9 @@ def __init__(
         return_forecasts_at_all_levels: bool = False,
         naive_reconciliation: bool = False,
         dtype: Type = np.float32,
+        impute_missing_values: bool = False,
+        imputation_method: Optional[MissingValueImputation] = None,
+        num_imputation_samples: int = 1,
     ) -> None:
         super().__init__(trainer=trainer, dtype=dtype)
 
@@ -202,11 +206,20 @@ def __init__(
 
         assert self.base_estimator_type == DeepAREstimatorForCOP
 
-        if "distr_output" not in base_estimator_hps:
-            base_estimator_hps["distr_output"] = GaussianOutput()
+        base_estimator_hps.setdefault("distr_output", GaussianOutput())
 
         print(f"Distribution output: {base_estimator_hps['distr_output']}")
 
+        base_estimator_hps.setdefault(
+            "impute_missing_values", impute_missing_values
+        )
+
+        base_estimator_hps.setdefault("imputation_method", imputation_method)
+
+        base_estimator_hps.setdefault(
+            "num_imputation_samples", num_imputation_samples
+        )
+
         self.estimators = []
         for agg_multiple, freq_str in zip(
             self.temporal_hierarchy.agg_multiples,
@@ -223,6 +236,14 @@
             num_nodes = self.temporal_hierarchy.num_leaves // agg_multiple
             lags_seq = [lag for lag in lags_seq if lag >= num_nodes]
 
+            # adapt window_length if RollingMeanValueImputation is used
+            if isinstance(imputation_method, RollingMeanValueImputation):
+                base_estimator_hps_agg[
+                    "imputation_method"
+                ] = RollingMeanValueImputation(
+                    window_size=imputation_method.window_size // agg_multiple
+                )
+
             # Hack to enforce correct serialization of lags_seq and history length
             # (only works when set in constructor).
             if agg_multiple != 1:

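Taken together, the changes above let the COP estimator accept the imputation-related settings and forward them to each per-level DeepAR estimator via dict.setdefault (so entries the caller already placed in base_estimator_hps still win), while a RollingMeanValueImputation window is divided by each level's aggregation multiple so it spans roughly the same stretch of time at every level of the temporal hierarchy. A small standalone sketch of that hyperparameter handling, with made-up aggregation multiples:

    from gluonts.transform import RollingMeanValueImputation

    # User-facing settings, injected only if the caller has not already
    # set them in base_estimator_hps.
    base_estimator_hps = {}
    imputation_method = RollingMeanValueImputation(window_size=24)
    base_estimator_hps.setdefault("impute_missing_values", True)
    base_estimator_hps.setdefault("imputation_method", imputation_method)

    # Hypothetical hourly -> 6-hourly -> daily hierarchy: the window shrinks
    # with the aggregation (24 // 1 = 24, 24 // 6 = 4, 24 // 24 = 1 steps),
    # i.e. it keeps covering roughly one day at every level.
    for agg_multiple in [1, 6, 24]:
        hps_agg = dict(base_estimator_hps)
        if isinstance(imputation_method, RollingMeanValueImputation):
            hps_agg["imputation_method"] = RollingMeanValueImputation(
                window_size=imputation_method.window_size // agg_multiple
            )
        print(agg_multiple, hps_agg["imputation_method"].window_size)
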
13 changes: 9 additions & 4 deletions src/gluonts/torch/model/estimator.py

@@ -213,10 +213,15 @@ def train_model(
             ckpt_path=ckpt_path,
         )
 
-        logger.info(f"Loading best model from {checkpoint.best_model_path}")
-        best_model = training_network.load_from_checkpoint(
-            checkpoint.best_model_path
-        )
+        if checkpoint.best_model_path != "":
+            logger.info(
+                f"Loading best model from {checkpoint.best_model_path}"
+            )
+            best_model = training_network.load_from_checkpoint(
+                checkpoint.best_model_path
+            )
+        else:
+            best_model = training_network
 
         return TrainOutput(
             transformation=transformation,

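The new guard handles the case where PyTorch Lightning never wrote a checkpoint, in which case ModelCheckpoint.best_model_path is an empty string and loading from it would fail; the network as it stands at the end of training is returned instead. A minimal sketch of the same fallback pattern outside GluonTS (names are illustrative):

    from pytorch_lightning.callbacks import ModelCheckpoint

    def resolve_best_model(checkpoint: ModelCheckpoint, trained_module):
        # best_model_path stays "" when no checkpoint was ever saved (for
        # example, the monitored metric was never logged); in that case the
        # weights already in memory are the best available.
        if checkpoint.best_model_path != "":
            return type(trained_module).load_from_checkpoint(
                checkpoint.best_model_path
            )
        return trained_module
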
2 changes: 1 addition & 1 deletion src/gluonts/torch/prelude.py

@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-# flake8: noqa: F401
+# flake8: noqa: F401, F403
 
 from .component import *
 from .model.forecast_generator import *
