From 08b59bc6e16242db6ee16cc5b1d129aeacc8a45b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 31 Jul 2023 14:44:41 +0200 Subject: [PATCH 1/2] [torch] Return a model even if callback has no best model path (#2952) --- src/gluonts/torch/model/estimator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 9311b474d0..2daa702d6b 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -209,10 +209,15 @@ def train_model( ckpt_path=ckpt_path, ) - logger.info(f"Loading best model from {checkpoint.best_model_path}") - best_model = training_network.load_from_checkpoint( - checkpoint.best_model_path - ) + if checkpoint.best_model_path != "": + logger.info( + f"Loading best model from {checkpoint.best_model_path}" + ) + best_model = training_network.load_from_checkpoint( + checkpoint.best_model_path + ) + else: + best_model = training_network return TrainOutput( transformation=transformation, From 03b2f80c4629c90f32ddfde0f9d31179b2c5dab7 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Mon, 16 Oct 2023 09:43:21 +0200 Subject: [PATCH 2/2] Move from `pytorch_lightning` to `lightning` (#3013) --- .../advanced_topics/howto_pytorch_lightning.md.template | 2 +- requirements/requirements-pytorch.txt | 4 +++- src/gluonts/torch/model/d_linear/estimator.py | 2 +- src/gluonts/torch/model/d_linear/lightning_module.py | 2 +- src/gluonts/torch/model/deepar/lightning_module.py | 2 +- src/gluonts/torch/model/estimator.py | 4 ++-- src/gluonts/torch/model/lag_tst/estimator.py | 2 +- src/gluonts/torch/model/lag_tst/lightning_module.py | 2 +- src/gluonts/torch/model/lightning_util.py | 2 +- src/gluonts/torch/model/mqf2/lightning_module.py | 2 +- src/gluonts/torch/model/patch_tst/estimator.py | 2 +- src/gluonts/torch/model/patch_tst/lightning_module.py | 2 +- src/gluonts/torch/model/simple_feedforward/estimator.py | 2 +- .../torch/model/simple_feedforward/lightning_module.py | 2 +- src/gluonts/torch/model/tft/lightning_module.py | 2 +- 15 files changed, 18 insertions(+), 16 deletions(-) diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index 01e170951e..4d54d5636a 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -134,7 +134,7 @@ To train the model using PyTorch Lightning, we only need to extend the class wit ```python -import pytorch_lightning as pl +import lightning.pytorch as pl ``` diff --git a/requirements/requirements-pytorch.txt b/requirements/requirements-pytorch.txt index 16a40f64a3..03f4e997ab 100644 --- a/requirements/requirements-pytorch.txt +++ b/requirements/requirements-pytorch.txt @@ -1,5 +1,7 @@ torch>=1.9,<3 -pytorch-lightning>=1.5,<3 +lightning>=1.8,<2.2 +# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually +pytorch_lightning>=1.8,<2.2 # Need to pin protobuf (for now) # See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 protobuf~=3.19.0 diff --git a/src/gluonts/torch/model/d_linear/estimator.py b/src/gluonts/torch/model/d_linear/estimator.py index 4f58caf4a0..e3f428db86 100644 --- a/src/gluonts/torch/model/d_linear/estimator.py +++ b/src/gluonts/torch/model/d_linear/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/d_linear/lightning_module.py b/src/gluonts/torch/model/d_linear/lightning_module.py index 28dccf1b97..bd081b45dd 100644 --- a/src/gluonts/torch/model/d_linear/lightning_module.py +++ b/src/gluonts/torch/model/d_linear/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/deepar/lightning_module.py b/src/gluonts/torch/model/deepar/lightning_module.py index fc676dfab3..8d190e2329 100644 --- a/src/gluonts/torch/model/deepar/lightning_module.py +++ b/src/gluonts/torch/model/deepar/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 2daa702d6b..9f41282ebb 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -15,7 +15,7 @@ import logging import numpy as np -import pytorch_lightning as pl +import lightning.pytorch as pl import torch.nn as nn from gluonts.core.component import validated @@ -213,7 +213,7 @@ def train_model( logger.info( f"Loading best model from {checkpoint.best_model_path}" ) - best_model = training_network.load_from_checkpoint( + best_model = training_network.__class__.load_from_checkpoint( checkpoint.best_model_path ) else: diff --git a/src/gluonts/torch/model/lag_tst/estimator.py b/src/gluonts/torch/model/lag_tst/estimator.py index 27bfd253b2..c3ae48237a 100644 --- a/src/gluonts/torch/model/lag_tst/estimator.py +++ b/src/gluonts/torch/model/lag_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any, List import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/lag_tst/lightning_module.py b/src/gluonts/torch/model/lag_tst/lightning_module.py index 2510944cfa..5c9e70e9e4 100644 --- a/src/gluonts/torch/model/lag_tst/lightning_module.py +++ b/src/gluonts/torch/model/lag_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/lightning_util.py b/src/gluonts/torch/model/lightning_util.py index 6742c8c7cf..73e2396140 100644 --- a/src/gluonts/torch/model/lightning_util.py +++ b/src/gluonts/torch/model/lightning_util.py @@ -13,7 +13,7 @@ from packaging import version -import pytorch_lightning as pl +import lightning.pytorch as pl def has_validation_loop(trainer: pl.Trainer): diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py index 6dc824beb4..16916c3c41 100644 --- a/src/gluonts/torch/model/mqf2/lightning_module.py +++ b/src/gluonts/torch/model/mqf2/lightning_module.py @@ -13,7 +13,7 @@ from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index c84576916f..860c18e193 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/patch_tst/lightning_module.py b/src/gluonts/torch/model/patch_tst/lightning_module.py index f5e95158b2..d80137ae05 100644 --- a/src/gluonts/torch/model/patch_tst/lightning_module.py +++ b/src/gluonts/torch/model/patch_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/simple_feedforward/estimator.py b/src/gluonts/torch/model/simple_feedforward/estimator.py index 5f82640e8b..05f34a9fa2 100644 --- a/src/gluonts/torch/model/simple_feedforward/estimator.py +++ b/src/gluonts/torch/model/simple_feedforward/estimator.py @@ -14,7 +14,7 @@ from typing import List, Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/simple_feedforward/lightning_module.py b/src/gluonts/torch/model/simple_feedforward/lightning_module.py index b7cf9a529a..f03473e78d 100644 --- a/src/gluonts/torch/model/simple_feedforward/lightning_module.py +++ b/src/gluonts/torch/model/simple_feedforward/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/tft/lightning_module.py b/src/gluonts/torch/model/tft/lightning_module.py index f6f7daa335..4647d740fd 100644 --- a/src/gluonts/torch/model/tft/lightning_module.py +++ b/src/gluonts/torch/model/tft/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated from gluonts.itertools import select