adding early stop callback to ptuning #6028

Merged: 42 commits, Feb 16, 2023
Commits (42, changes from all commits)
2b95406
patch to allow using tokenizers without additional_special_tokens_ids…
arendu Dec 15, 2022
c131a90
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Dec 16, 2022
9e15c3a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Dec 20, 2022
d0e3669
merge main
arendu Jan 5, 2023
0a19a5a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
ec3d57b
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
64e36ba
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 6, 2023
5bfde7e
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 7, 2023
b04b145
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 10, 2023
b1906ab
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 11, 2023
9795062
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 12, 2023
0f83085
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 19, 2023
ee4dd1a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 20, 2023
53ba0b2
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 23, 2023
a6aee2a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 24, 2023
33442d4
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 30, 2023
8e6c5c9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Jan 31, 2023
efd263c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 8, 2023
ecfda4f
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 10, 2023
15aee0c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
arendu Feb 15, 2023
3fe8e34
early stop callback for prompt/p tuning
arendu Feb 15, 2023
0d2666c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2023
1af3e79
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
ecbb118
update
arendu Feb 15, 2023
8af4d3e
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 15, 2023
b0d9ea4
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
8651935
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
2b569e5
added exp manager config for early stop
arendu Feb 15, 2023
d6f48d1
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 15, 2023
409d94e
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
a10284d
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
f155911
pushed logic for creating early stopping inside exp manager
arendu Feb 15, 2023
858d46a
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 15, 2023
6ae6188
pushed logic for creating early stopping inside exp manager
arendu Feb 15, 2023
0aefe59
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 15, 2023
2f1111d
minor updates and added dataclass check
arendu Feb 16, 2023
60e0d25
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
2b8d254
more args
arendu Feb 16, 2023
cac3df9
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 16, 2023
2f35842
Merge branch 'main' into adithyare/early_stop_ptuning
arendu Feb 16, 2023
a176efb
more args
arendu Feb 16, 2023
c31d4aa
Merge branch 'adithyare/early_stop_ptuning' of https://github.com/NVI…
arendu Feb 16, 2023
@@ -15,6 +15,7 @@ trainer:
gradient_clip_val: 1.0
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
benchmark: False



exp_manager:
@@ -36,6 +37,14 @@ exp_manager:
filename: 'megatron_gpt_prompt_tune--{val_loss:.3f}-{step}'
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: True
create_early_stopping_callback: True
early_stopping_callback_params:
monitor: "val_loss"
mode: "min"
min_delta: 0.001
patience: 10
verbose: True


model:
seed: 1234
@@ -34,6 +34,13 @@ exp_manager:
filename: "megatron_t5_prompt_tune--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}"
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: True
create_early_stopping_callback: True
early_stopping_callback_params:
monitor: "val_loss"
mode: "min"
min_delta: 0.001
patience: 10
verbose: True

model:
seed: 1234
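Both prompt-learning configs above (GPT and T5) gain the same two exp_manager fields. As a rough, assumed sketch of how these fields could be toggled from Python instead of editing the YAML directly (the config filename below is a placeholder, not something introduced by this PR):

# Hypothetical usage sketch; the YAML path is assumed, not part of this PR.
from omegaconf import OmegaConf

cfg = OmegaConf.load("megatron_gpt_prompt_learning_config.yaml")  # assumed filename
cfg.exp_manager.create_early_stopping_callback = True
cfg.exp_manager.early_stopping_callback_params.patience = 5        # wait fewer validation cycles
cfg.exp_manager.early_stopping_callback_params.min_delta = 0.0005  # tighter improvement threshold

Leaving create_early_stopping_callback at its default (False) keeps the previous behaviour.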
24 changes: 24 additions & 0 deletions nemo/utils/exp_manager.py
@@ -32,6 +32,7 @@
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.timer import Interval, Timer
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger
from pytorch_lightning.loops import TrainingEpochLoop
@@ -69,6 +70,21 @@ class CheckpointMisconfigurationError(NeMoBaseException):
""" Raised when a mismatch between trainer.callbacks and exp_manager occurs"""


@dataclass
class EarlyStoppingParams:
monitor: str = "val_loss" # The metric that early stopping should consider.
mode: str = "min" # inform early stopping whether to look for increase or decrease in monitor.
min_delta: float = 0.001 # smallest change to consider as improvement.
patience: int = 10 # how many consecutive validation cycles to wait with no improvement before stopping training.
verbose: bool = True
strict: bool = True
check_finite: bool = True
stopping_threshold: Optional[float] = None
divergence_threshold: Optional[float] = None
check_on_train_epoch_end: Optional[bool] = None
log_rank_zero_only: bool = False


@dataclass
class CallbackParams:
filepath: Optional[str] = None # Deprecated
@@ -153,6 +169,8 @@ class ExpManagerConfig:
# Checkpointing parameters
create_checkpoint_callback: Optional[bool] = True
checkpoint_callback_params: Optional[CallbackParams] = CallbackParams()
create_early_stopping_callback: Optional[bool] = False
early_stopping_callback_params: Optional[EarlyStoppingParams] = EarlyStoppingParams()
# Additional exp_manager arguments
files_to_copy: Optional[List[str]] = None
# logs timing of train/val/test steps
@@ -272,6 +290,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most
recent checkpoint under ``*last.ckpt``, and the final checkpoint after training completes under ``*end.ckpt``.
Defaults to True.
- create_early_stopping_callback (bool): Flag to decide if early stopping should be used to stop training. Defaults to False.
See EarlyStoppingParams dataclass above.
- files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which
copies no files.
- log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False.
@@ -420,6 +440,10 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
)
trainer.callbacks.append(ema_callback)

if cfg.create_early_stopping_callback:
early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
trainer.callbacks.append(early_stop_callback)

if cfg.create_checkpoint_callback:
configure_checkpointing(
trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params
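For readers unfamiliar with the exp_manager flow, the new branch above simply builds an EarlyStopping callback from early_stopping_callback_params and appends it to the trainer's callbacks. A minimal sketch of the equivalent manual wiring (assumed for illustration, not code from this PR), using the EarlyStoppingParams defaults:

# Minimal sketch; values mirror the EarlyStoppingParams defaults shown above.
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

early_stop_callback = EarlyStopping(
    monitor="val_loss",  # metric logged during validation
    mode="min",          # stop when the metric stops decreasing
    min_delta=0.001,     # smallest change counted as an improvement
    patience=10,         # validation cycles to wait without improvement
    verbose=True,
)
trainer = pl.Trainer()
trainer.callbacks.append(early_stop_callback)

With this in place, training halts once the monitored metric has not improved by at least min_delta for patience consecutive validation runs.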
11 changes: 11 additions & 0 deletions tests/core/test_config_utils.py
@@ -17,9 +17,11 @@

import pytest
import pytorch_lightning as ptl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import config_utils
from nemo.utils.exp_manager import EarlyStoppingParams


@pytest.fixture()
@@ -126,3 +128,12 @@ def test_ptl_config(self):
assert signatures_match
assert cls_subset is None
assert dataclass_subset is None

@pytest.mark.unit
def test_early_stopping_config(self,):
Inline review conversation on this test:

arendu (Collaborator, Author): @titu1994 I can't figure out what is wrong here.

Collaborator: Seems to have worked out
result = config_utils.assert_dataclass_signature_match(EarlyStopping, EarlyStoppingParams)
signatures_match, cls_subset, dataclass_subset = result

assert signatures_match
assert cls_subset is None
assert dataclass_subset is None
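The assertion helper used in this test is NeMo's own; roughly speaking, it checks that the dataclass fields line up with the callback constructor's parameters, so the params object can be unpacked with ** the way exp_manager does. A simplified, assumed illustration of that idea (not NeMo's actual implementation):

# Simplified sketch of the signature check; not the real assert_dataclass_signature_match.
import dataclasses
import inspect

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from nemo.utils.exp_manager import EarlyStoppingParams

init_params = set(inspect.signature(EarlyStopping.__init__).parameters) - {"self"}
dataclass_fields = {f.name for f in dataclasses.fields(EarlyStoppingParams)}
assert dataclass_fields <= init_params  # every field maps to an EarlyStopping kwarg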