From a1f7c086ca62fe9608be6e7605793b4be192e963 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Thu, 24 Oct 2024 12:52:58 -0700
Subject: [PATCH 1/8] Configure no restart validation loop in nl.Trainer

Signed-off-by: Hemil Desai
---
 nemo/lightning/pytorch/trainer.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 0d71c49bf198..0d1e8a9b1199 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -12,18 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from copy import deepcopy
 
 import fiddle as fdl
 import pytorch_lightning as pl
+from pytorch_lightning.loops import _TrainingEpochLoop
 from typing_extensions import Self
 
 from nemo.lightning.fabric.conversion import to_fabric
 from nemo.lightning.fabric.fabric import Fabric
 from nemo.lightning.io.mixin import IOMixin, serialization, track_io
+from nemo.utils.exp_manager import SkipResumeTrainingValidationLoop
 
 
 class Trainer(pl.Trainer, IOMixin):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _configure_no_restart_validation_training_loop(self) -> None:
+        if not isinstance(self.fit_loop.epoch_loop, _TrainingEpochLoop):
+            warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
+            return
+
+        ## Pass trainer object to avoid trainer getting overwritten as None
+        loop = SkipResumeTrainingValidationLoop(self, self.min_steps, self.max_steps)
+        self.fit_loop.epoch_loop = loop
 
     def add_io(self, obj):
         """Recurse to the leaves of a container and add io functionality to non-serializable leaves"""

From b46de13014b35b2fe56006fd793f71f1f613d01f Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Thu, 24 Oct 2024 12:54:49 -0700
Subject: [PATCH 2/8] fix

Signed-off-by: Hemil Desai
---
 nemo/lightning/pytorch/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 0d1e8a9b1199..06417dfadbe8 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -29,6 +29,7 @@
 class Trainer(pl.Trainer, IOMixin):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._configure_no_restart_validation_training_loop()
 
     def _configure_no_restart_validation_training_loop(self) -> None:
         if not isinstance(self.fit_loop.epoch_loop, _TrainingEpochLoop):

From 8d72c59c94e9ababa5c4e314dd4fc6a1dddf90c3 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Fri, 25 Oct 2024 13:38:51 -0700
Subject: [PATCH 3/8] Skip validation whenever restarting=True

Signed-off-by: Hemil Desai
---
 nemo/utils/exp_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index b512bc57cbab..cdee85159377 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -1279,7 +1279,7 @@ class SkipResumeTrainingValidationLoop(_TrainingEpochLoop):
     """
 
     def _should_check_val_fx(self, data_fetcher) -> bool:
-        if self.restarting and self.global_step % self.trainer.val_check_batch == 0:
+        if self.restarting:
             return False
         return super()._should_check_val_fx(data_fetcher)
 

From 51d786636a244cdafdb76fd57376e739b1adf04c Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Mon, 4 Nov 2024 16:14:36 -0800
Subject: [PATCH 4/8] PR feedback

Signed-off-by: Hemil Desai
---
 nemo/lightning/pytorch/trainer.py | 15 +++++++++++++--
 nemo/utils/exp_manager.py         |  2 +-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 06417dfadbe8..252c4bd0fb14 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -23,9 +23,20 @@
 from nemo.lightning.fabric.conversion import to_fabric
 from nemo.lightning.fabric.fabric import Fabric
 from nemo.lightning.io.mixin import IOMixin, serialization, track_io
-from nemo.utils.exp_manager import SkipResumeTrainingValidationLoop
 
 
+class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
+    """
+    Extend the PTL Epoch loop to skip validation when restarting.
+    This happens when resuming a checkpoint that has already run validation, but loading restores
+    the training state before validation has run.
+    """
+
+    def _should_check_val_fx(self, data_fetcher) -> bool:
+        if self.restarting:
+            return False
+        return super()._should_check_val_fx(data_fetcher)
+
 class Trainer(pl.Trainer, IOMixin):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -37,7 +48,7 @@ def _configure_no_restart_validation_training_loop(self) -> None:
             return
 
         ## Pass trainer object to avoid trainer getting overwritten as None
-        loop = SkipResumeTrainingValidationLoop(self, self.min_steps, self.max_steps)
+        loop = NoValOnRestartTrainingLoop(self, self.min_steps, self.max_steps)
         self.fit_loop.epoch_loop = loop
 
     def add_io(self, obj):
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index cdee85159377..b512bc57cbab 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -1279,7 +1279,7 @@ class SkipResumeTrainingValidationLoop(_TrainingEpochLoop):
     """
 
     def _should_check_val_fx(self, data_fetcher) -> bool:
-        if self.restarting:
+        if self.restarting and self.global_step % self.trainer.val_check_batch == 0:
             return False
         return super()._should_check_val_fx(data_fetcher)
 

From 16fa5b2fe0c9ad9ae7c2503e6431390959df6fa1 Mon Sep 17 00:00:00 2001
From: hemildesai
Date: Tue, 5 Nov 2024 00:15:27 +0000
Subject: [PATCH 5/8] Apply isort and black reformatting

Signed-off-by: hemildesai
---
 nemo/lightning/pytorch/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 252c4bd0fb14..010aac26f78b 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -37,6 +37,7 @@ def _should_check_val_fx(self, data_fetcher) -> bool:
             return False
         return super()._should_check_val_fx(data_fetcher)
 
+
 class Trainer(pl.Trainer, IOMixin):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

From 741161c2b36b1ae8d4204305df27dc7e0b205333 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Fri, 8 Nov 2024 14:20:49 -0800
Subject: [PATCH 6/8] fix tests

Signed-off-by: Hemil Desai
---
 nemo/collections/llm/api.py       | 10 +++++++++-
 nemo/lightning/__init__.py        |  3 ++-
 nemo/lightning/pytorch/trainer.py | 14 ++++++++++----
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py
index 13f25eb21087..fdceff5d959e 100644
--- a/nemo/collections/llm/api.py
+++ b/nemo/collections/llm/api.py
@@ -25,7 +25,14 @@
 from typing_extensions import Annotated
 
 import nemo.lightning as nl
-from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
+from nemo.lightning import (
+    AutoResume,
+    NeMoLogger,
+    OptimizerModule,
+    Trainer,
+    configure_no_restart_validation_training_loop,
+    io,
+)
 from nemo.lightning.base import NEMO_MODELS_CACHE
 from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
 from nemo.utils import logging
@@ -680,6 +687,7 @@ def _setup(
     tokenizer: Optional[TokenizerType],
     model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
 ) -> Any:  # Return type is Any because app_state's type is not specified
+    configure_no_restart_validation_training_loop(trainer)
     _log = log or NeMoLogger()
     if resume and isinstance(model_transform, PEFT) and _log.ckpt:
         logging.info("Disabling try_restore_best_ckpt restoration for adapters")
diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py
index 2cc720e148d4..91d3b3f936d0 100644
--- a/nemo/lightning/__init__.py
+++ b/nemo/lightning/__init__.py
@@ -33,7 +33,7 @@
 from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
 from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
-from nemo.lightning.pytorch.trainer import Trainer
+from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop
 from nemo.lightning.resume import AutoResume
 
 
@@ -66,6 +66,7 @@ def _is_slurm_interactive_mode():
     "ModelCheckpoint",
     "OptimizerModule",
     "Trainer",
+    "configure_no_restart_validation_training_loop",
     "get_vocab_size",
     "teardown",
 ]
diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 010aac26f78b..0cba14076c52 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -38,11 +38,17 @@ def _should_check_val_fx(self, data_fetcher) -> bool:
         return super()._should_check_val_fx(data_fetcher)
 
 
-class Trainer(pl.Trainer, IOMixin):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._configure_no_restart_validation_training_loop()
+def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
+    if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop):
+        warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
+        return
+
+    ## Pass trainer object to avoid trainer getting overwritten as None
+    loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps)
+    trainer.fit_loop.epoch_loop = loop
+
 
+class Trainer(pl.Trainer, IOMixin):
     def _configure_no_restart_validation_training_loop(self) -> None:
         if not isinstance(self.fit_loop.epoch_loop, _TrainingEpochLoop):
             warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)

From 98c25bbefa114e9c4278787a95c5a18095733ce9 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Fri, 8 Nov 2024 14:33:59 -0800
Subject: [PATCH 7/8] fix

Signed-off-by: Hemil Desai
---
 nemo/lightning/pytorch/trainer.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 0cba14076c52..01f9a57de313 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -49,15 +49,6 @@ def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
 
 
 class Trainer(pl.Trainer, IOMixin):
-    def _configure_no_restart_validation_training_loop(self) -> None:
-        if not isinstance(self.fit_loop.epoch_loop, _TrainingEpochLoop):
-            warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
-            return
-
-        ## Pass trainer object to avoid trainer getting overwritten as None
-        loop = NoValOnRestartTrainingLoop(self, self.min_steps, self.max_steps)
-        self.fit_loop.epoch_loop = loop
-
     def add_io(self, obj):
         """Recurse to the leaves of a container and add io functionality to non-serializable leaves"""
         if isinstance(obj, (dict, list)):

From 3e13003b0f3d8ca51f36dfc287cad692eaa91a35 Mon Sep 17 00:00:00 2001
From: Maanu Grover
Date: Wed, 13 Nov 2024 10:30:07 -0800
Subject: [PATCH 8/8] use new var

Signed-off-by: Maanu Grover
---
 nemo/lightning/pytorch/trainer.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 01f9a57de313..c97c59ef524d 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -18,6 +18,7 @@
 import fiddle as fdl
 import pytorch_lightning as pl
 from pytorch_lightning.loops import _TrainingEpochLoop
+from pytorch_lightning.loops.fetchers import _DataFetcher
 from typing_extensions import Self
 
 from nemo.lightning.fabric.conversion import to_fabric
@@ -33,10 +34,20 @@ class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
     """
 
     def _should_check_val_fx(self, data_fetcher) -> bool:
-        if self.restarting:
+        if self.skip_val_on_restart:
             return False
         return super()._should_check_val_fx(data_fetcher)
 
+    def load_state_dict(self, state_dict: dict, prefix: str = "") -> None:
+        super().load_state_dict(state_dict, prefix)
+
+        self.skip_val_on_restart = True
+
+    def advance(self, data_fetcher: _DataFetcher) -> None:
+        super().advance(data_fetcher)
+
+        self.skip_val_on_restart = False
+
 
 def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
     if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop):
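
The end state of the series: NoValOnRestartTrainingLoop keeps a skip_val_on_restart flag that load_state_dict sets to True (it only runs when trainer state is restored from a checkpoint) and advance resets to False once the first post-resume batch completes, so only the one validation check that had already run before the checkpoint was saved gets suppressed. Since PATCH 6/8, llm._setup() wires this up automatically; standalone trainers opt in by calling the helper directly. A minimal usage sketch, not part of the patches, where the Trainer arguments and the model/data objects are illustrative placeholders:

    import nemo.lightning as nl

    trainer = nl.Trainer(max_steps=1000, val_check_interval=100)

    # Swap the restart-aware epoch loop into the trainer before fitting.
    # If a custom (non-_TrainingEpochLoop) epoch loop is already installed,
    # the helper emits a UserWarning and leaves it untouched.
    nl.configure_no_restart_validation_training_loop(trainer)

    # trainer.fit(model, data)  # hypothetical model/data objects
    # On a fresh run, validation fires as usual at each val_check_interval.
    # On resume, the first validation check after restore is skipped, since
    # load_state_dict() set skip_val_on_restart = True and the flag is only
    # cleared after the first advance() completes.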