adding early stop callback to ptuning (#6028)
* patch to allow using tokenizers without additional_special_tokens_ids attribute

Signed-off-by: arendu <adithya.r@gmail.com>

* early stop callback for prompt/p tuning

Signed-off-by: arendu <adithya.r@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: arendu <adithya.r@gmail.com>

* added exp manager config for early stop

Signed-off-by: arendu <adithya.r@gmail.com>

* pushed logic for creating early stopping inside exp manager

Signed-off-by: arendu <adithya.r@gmail.com>

* pushed logic for creating early stopping inside exp manager

Signed-off-by: arendu <adithya.r@gmail.com>

* minor updates and added dataclass check

Signed-off-by: arendu <adithya.r@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more args

Signed-off-by: arendu <adithya.r@gmail.com>

* more args

Signed-off-by: arendu <adithya.r@gmail.com>

---------

Signed-off-by: arendu <adithya.r@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and Davood-M committed Feb 16, 2023
1 parent 6ebc0a6 commit 3ed368d
Showing 4 changed files with 51 additions and 0 deletions.
@@ -15,6 +15,7 @@ trainer:
gradient_clip_val: 1.0
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
benchmark: False



exp_manager:
@@ -36,6 +37,14 @@
filename: 'megatron_gpt_prompt_tune--{val_loss:.3f}-{step}'
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: True
create_early_stopping_callback: True
early_stopping_callback_params:
monitor: "val_loss"
mode: "min"
min_delta: 0.001
patience: 10
verbose: True


model:
seed: 1234
@@ -34,6 +34,13 @@ exp_manager:
filename: "megatron_t5_prompt_tune--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}"
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: True
create_early_stopping_callback: True
early_stopping_callback_params:
monitor: "val_loss"
mode: "min"
min_delta: 0.001
patience: 10
verbose: True

model:
seed: 1234
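For illustration only (not part of this commit): a minimal OmegaConf sketch of how the exp_manager keys added in the two YAML hunks above could be loaded and tweaked programmatically before a prompt-learning run. The values mirror the defaults shown here; everything else about the run config is omitted.

from omegaconf import OmegaConf

# Just the exp_manager fragment added in the YAML hunks above.
cfg = OmegaConf.create(
    {
        "exp_manager": {
            "create_early_stopping_callback": True,
            "early_stopping_callback_params": {
                "monitor": "val_loss",
                "mode": "min",
                "min_delta": 0.001,
                "patience": 10,
                "verbose": True,
            },
        }
    }
)

# A quick experiment might tighten the stopping criterion, or disable the callback entirely.
cfg.exp_manager.early_stopping_callback_params.patience = 3
# cfg.exp_manager.create_early_stopping_callback = False

print(OmegaConf.to_yaml(cfg.exp_manager))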
24 changes: 24 additions & 0 deletions nemo/utils/exp_manager.py
@@ -32,6 +32,7 @@
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.timer import Interval, Timer
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger
from pytorch_lightning.loops import TrainingEpochLoop
@@ -69,6 +70,21 @@ class CheckpointMisconfigurationError(NeMoBaseException):
""" Raised when a mismatch between trainer.callbacks and exp_manager occurs"""


@dataclass
class EarlyStoppingParams:
monitor: str = "val_loss" # The metric that early stopping should consider.
mode: str = "min"  # whether the monitored metric should be minimized ("min") or maximized ("max").
min_delta: float = 0.001  # smallest change in the monitored metric that counts as an improvement.
patience: int = 10  # number of consecutive validation cycles with no improvement before training is stopped.
verbose: bool = True
strict: bool = True  # crash if the monitored metric is not found in the logged metrics.
check_finite: bool = True  # stop training if the monitored metric becomes NaN or infinite.
stopping_threshold: Optional[float] = None  # stop immediately once the metric reaches this value.
divergence_threshold: Optional[float] = None  # stop once the metric becomes worse than this value.
check_on_train_epoch_end: Optional[bool] = None  # check at train-epoch end instead of after validation; None lets Lightning decide.
log_rank_zero_only: bool = False  # log the early-stopping status on rank 0 only.


@dataclass
class CallbackParams:
filepath: Optional[str] = None # Deprecated
@@ -152,6 +168,8 @@ class ExpManagerConfig:
# Checkpointing parameters
create_checkpoint_callback: Optional[bool] = True
checkpoint_callback_params: Optional[CallbackParams] = CallbackParams()
create_early_stopping_callback: Optional[bool] = False
early_stopping_callback_params: Optional[EarlyStoppingParams] = EarlyStoppingParams()
# Additional exp_manager arguments
files_to_copy: Optional[List[str]] = None
# logs timing of train/val/test steps
@@ -270,6 +288,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most
recent checkpoint under ``*last.ckpt``, and the final checkpoint after training completes under ``*end.ckpt``.
Defaults to True.
- create_early_stopping_callback (bool): Whether to add an EarlyStopping callback that stops training once the
monitored metric stops improving. Defaults to False. See the EarlyStoppingParams dataclass above for the available parameters.
- files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which
copies no files.
- log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False.
@@ -418,6 +438,10 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
)
trainer.callbacks.append(ema_callback)

if cfg.create_early_stopping_callback:
early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
trainer.callbacks.append(early_stop_callback)

if cfg.create_checkpoint_callback:
configure_checkpointing(
trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params
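A rough end-to-end sketch of how the new flag is consumed when exp_manager is called (not from this commit: the trainer settings are placeholders for a small CPU run, and only the two early-stopping keys come from this change):

import pytorch_lightning as ptl
from omegaconf import OmegaConf

from nemo.utils.exp_manager import exp_manager

# exp_manager creates its own logger and checkpoint callback, so the trainer is
# built with both disabled (placeholder settings).
trainer = ptl.Trainer(max_epochs=10, accelerator="cpu", devices=1, logger=False, enable_checkpointing=False)

cfg = OmegaConf.create(
    {
        "create_early_stopping_callback": True,
        "early_stopping_callback_params": {"monitor": "val_loss", "mode": "min", "min_delta": 0.001, "patience": 10},
    }
)

# With the flag set, exp_manager appends an EarlyStopping callback built from
# early_stopping_callback_params, so training stops once val_loss has failed to
# improve by min_delta for `patience` consecutive validation runs.
log_dir = exp_manager(trainer, cfg)
print([type(cb).__name__ for cb in trainer.callbacks])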
11 changes: 11 additions & 0 deletions tests/core/test_config_utils.py
@@ -17,9 +17,11 @@

import pytest
import pytorch_lightning as ptl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import config_utils
from nemo.utils.exp_manager import EarlyStoppingParams


@pytest.fixture()
@@ -126,3 +128,12 @@ def test_ptl_config(self):
assert signatures_match
assert cls_subset is None
assert dataclass_subset is None

@pytest.mark.unit
def test_early_stopping_config(self,):
result = config_utils.assert_dataclass_signature_match(EarlyStopping, EarlyStoppingParams)
signatures_match, cls_subset, dataclass_subset = result

assert signatures_match
assert cls_subset is None
assert dataclass_subset is None
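The assertion above relies on NeMo's config_utils helper; conceptually it guards against drift between the Lightning callback and the NeMo config schema. A hand-rolled sketch of the same idea (not the helper's actual implementation) would be:

import inspect
from dataclasses import fields

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from nemo.utils.exp_manager import EarlyStoppingParams

# Constructor arguments accepted by the Lightning callback, minus 'self'.
cls_args = {name for name in inspect.signature(EarlyStopping.__init__).parameters if name != "self"}
# Fields declared on the NeMo-side config dataclass.
dc_fields = {f.name for f in fields(EarlyStoppingParams)}

# Any asymmetry means EarlyStoppingParams has drifted from the EarlyStopping API,
# e.g. after a PyTorch Lightning upgrade adds or renames an argument.
print("only in EarlyStopping:", cls_args - dc_fields)
print("only in EarlyStoppingParams:", dc_fields - cls_args)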
