Skip to content

Commit

Permalink
Merge branch 'dev' into cop_estimator_freq
Browse files Browse the repository at this point in the history
  • Loading branch information
rshyamsundar authored Aug 2, 2023
2 parents 0543b12 + 851f3be commit 627e794
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/nursery/daf/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ tensorboard==2.3.0
numpy==1.22.0
pandas==1.1.5
scikit-learn==0.23.2
scipy==1.5.2
scipy==1.10.0
matplotlib==3.3.2
13 changes: 9 additions & 4 deletions src/gluonts/torch/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,15 @@ def train_model(
ckpt_path=ckpt_path,
)

logger.info(f"Loading best model from {checkpoint.best_model_path}")
best_model = training_network.load_from_checkpoint(
checkpoint.best_model_path
)
if checkpoint.best_model_path != "":
logger.info(
f"Loading best model from {checkpoint.best_model_path}"
)
best_model = training_network.load_from_checkpoint(
checkpoint.best_model_path
)
else:
best_model = training_network

return TrainOutput(
transformation=transformation,
Expand Down

0 comments on commit 627e794

Please sign in to comment.