diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index f0c71771e1..003d88f7fa 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -213,10 +213,15 @@ def train_model( ckpt_path=ckpt_path, ) - logger.info(f"Loading best model from {checkpoint.best_model_path}") - best_model = training_network.load_from_checkpoint( - checkpoint.best_model_path - ) + if checkpoint.best_model_path != "": + logger.info( + f"Loading best model from {checkpoint.best_model_path}" + ) + best_model = training_network.load_from_checkpoint( + checkpoint.best_model_path + ) + else: + best_model = training_network return TrainOutput( transformation=transformation,