Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configure no restart validation loop in nl.Trainer #11029

Merged
merged 8 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@
from typing_extensions import Annotated

import nemo.lightning as nl
from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io
from nemo.lightning import (
AutoResume,
NeMoLogger,
OptimizerModule,
Trainer,
configure_no_restart_validation_training_loop,
io,
)
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging
Expand Down Expand Up @@ -680,6 +687,7 @@ def _setup(
tokenizer: Optional[TokenizerType],
model_transform: Optional[Union[PEFT, ModelTransform, Callable]],
) -> Any: # Return type is Any because app_state's type is not specified
configure_no_restart_validation_training_loop(trainer)
_log = log or NeMoLogger()
if resume and isinstance(model_transform, PEFT) and _log.ckpt:
logging.info("Disabling try_restore_best_ckpt restoration for adapters")
Expand Down
3 changes: 2 additions & 1 deletion nemo/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.lightning.pytorch.trainer import Trainer
from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop
from nemo.lightning.resume import AutoResume


Expand Down Expand Up @@ -66,6 +66,7 @@ def _is_slurm_interactive_mode():
"ModelCheckpoint",
"OptimizerModule",
"Trainer",
"configure_no_restart_validation_training_loop",
"get_vocab_size",
"teardown",
]
26 changes: 25 additions & 1 deletion nemo/lightning/pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy

import fiddle as fdl
import pytorch_lightning as pl
from pytorch_lightning.loops import _TrainingEpochLoop
from typing_extensions import Self

from nemo.lightning.fabric.conversion import to_fabric
from nemo.lightning.fabric.fabric import Fabric
from nemo.lightning.io.mixin import IOMixin, serialization, track_io


class Trainer(pl.Trainer, IOMixin):
class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
    """
    Training-epoch loop that suppresses the validation pass triggered on restart.

    When training resumes from a checkpoint saved after a validation run, the
    restored loop state predates that validation, so the stock PTL loop would
    re-run it. This subclass skips validation while the loop reports that it is
    restarting, and otherwise defers entirely to the parent implementation.
    """

    def _should_check_val_fx(self, data_fetcher) -> bool:
        # While restoring from a checkpoint, never trigger validation;
        # in every other state, use the stock PTL decision.
        if not self.restarting:
            return super()._should_check_val_fx(data_fetcher)
        return False


def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
    """
    Swap the trainer's epoch loop for one that skips validation on restart.

    If the trainer already carries a custom (non-stock) epoch loop, emit a
    warning and leave it untouched, since replacing it could silently discard
    user-provided behavior.
    """
    current_loop = trainer.fit_loop.epoch_loop
    if not isinstance(current_loop, _TrainingEpochLoop):
        warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
        return

    ## Pass trainer object to avoid trainer getting overwritten as None
    replacement = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps)
    trainer.fit_loop.epoch_loop = replacement


class Trainer(pl.Trainer, IOMixin):
def add_io(self, obj):
"""Recurse to the leaves of a container and add io functionality to non-serializable leaves"""
if isinstance(obj, (dict, list)):
Expand Down
Loading