add reset learning rate functionality (#9372)
* add reset_lr functionality

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix reset_lr logic

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* move reset_lr from optim section

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* add reset_lr value to config

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* set reset_lr False by default

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* remove extra line

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add reset_lr test

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add reset_lr test

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* remove extra quote

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add ability to reset schedule's max_steps and decay_steps

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* change scheduler's first step logic when using reset_lr

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* revert config

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix reset_lr logic

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* revert config

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* revert config

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update reset_lr comments

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add use cases for reset_lr feature

Signed-off-by: dimapihtar <dpihtar@gmail.com>

---------

Signed-off-by: dimapihtar <dpihtar@gmail.com>
Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>
Co-authored-by: dimapihtar <dimapihtar@users.noreply.github.com>
dimapihtar and dimapihtar authored Jun 27, 2024
1 parent f49f2e9 commit 397ed6a
Showing 5 changed files with 148 additions and 6 deletions.
84 changes: 84 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -2630,6 +2630,89 @@ jobs:
# }
# }

  L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2:
    needs: [cicd-test-container-setup]
    runs-on: self-hosted-azure
    timeout-minutes: 10
    container:
      image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
      options:
        # --user 0:128
        --device=/dev/nvidia0
        --gpus all
        --shm-size=8g
        --env TRANSFORMERS_OFFLINE=0
        --env HYDRA_FULL_ERROR=1
        --volume /mnt/datadrive/TestData:/home/TestData
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - run: |
          python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
            trainer.devices=2 \
            trainer.accelerator=gpu \
            trainer.log_every_n_steps=1 \
            trainer.val_check_interval=3 \
            trainer.limit_val_batches=2 \
            trainer.accumulate_grad_batches=1 \
            trainer.max_steps=3 \
            trainer.precision=bf16 \
            trainer.gradient_clip_val=1.0 \
            exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
            model.tensor_model_parallel_size=2 \
            model.megatron_amp_O2=True \
            model.optim.name=distributed_fused_adam \
            model.optim.lr=2e-4 \
            model.optim.sched.warmup_steps=2 \
            model.optim.sched.constant_steps=2 \
            model.optim.sched.min_lr=8e-5 \
            model.max_position_embeddings=128 \
            model.encoder_seq_length=128 \
            model.data.seq_length=128 \
            model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
            model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
            model.num_layers=8 \
            model.hidden_size=256 \
            model.num_attention_heads=8 \
            model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \
            model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings

          python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
            trainer.devices=2 \
            trainer.accelerator=gpu \
            trainer.log_every_n_steps=1 \
            trainer.val_check_interval=3 \
            trainer.limit_val_batches=2 \
            trainer.accumulate_grad_batches=1 \
            trainer.max_steps=6 \
            trainer.precision=bf16 \
            trainer.gradient_clip_val=1.0 \
            exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
            exp_manager.resume_if_exists=True \
            model.reset_lr=True \
            model.tensor_model_parallel_size=2 \
            model.megatron_amp_O2=True \
            model.optim.name=distributed_fused_adam \
            model.optim.lr=2e-4 \
            model.optim.sched.warmup_steps=2 \
            model.optim.sched.constant_steps=2 \
            model.optim.sched.min_lr=8e-5 \
            model.max_position_embeddings=128 \
            model.encoder_seq_length=128 \
            model.data.seq_length=128 \
            model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
            model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
            model.num_layers=8 \
            model.hidden_size=256 \
            model.num_attention_heads=8 \
            model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \
            model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings

          rm -rf examples/nlp/language_modeling/gpt_pretrain_results
          rm -rf examples/nlp/language_modeling/gpt_index_mappings
      - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
        if: "failure()"

  L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2:
    needs: [cicd-test-container-setup]
    runs-on: self-hosted-azure
@@ -4296,6 +4379,7 @@ jobs:
- L2_BioMegatron_Bert_NER_Task
- L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_KERPLE_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_Pretraining_and_Resume_Training_PP2
8 changes: 8 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -115,6 +115,14 @@ model:
  seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595.
  num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used.

  ## Reset learning rate schedule.
  # 1. reset_lr=True, reset_lr_steps=False: pre-train an existing checkpoint "from scratch" on a different dataset.
  # 2. reset_lr=True, reset_lr_steps=True: continue training from an existing checkpoint with the same configuration.
  #    The scheduler's max_steps and decay_steps are recalculated as max_steps -= completed_steps and
  #    decay_steps -= completed_steps, where completed_steps is the number of steps already completed at the checkpoint.
  #    This helps reach the min_lr value by the end of training without changing trainer.max_steps.
  reset_lr: False # Set to True to reset the learning rate to its initial value. Only supported with the distributed optimizer and megatron_amp_O2.
  reset_lr_steps: False # Set to True to adjust the scheduler's max_steps and decay_steps by subtracting the number of steps already completed at the checkpoint.
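
A rough illustration of the arithmetic described in the comments above (not part of this commit; all numbers are hypothetical):

completed_steps = 3   # steps already completed at the checkpoint (trainer.global_step)
max_steps = 6         # scheduler max_steps from the config
decay_steps = 6       # scheduler decay_steps from the config

# With reset_lr=True and reset_lr_steps=True the schedule restarts at step 0,
# and its horizon shrinks by the completed steps, so min_lr is still reached
# at the original trainer.max_steps:
max_steps -= completed_steps    # 3
decay_steps -= completed_steps  # 3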

  tokenizer:
    library: 'megatron'
    type: 'GPT2BPETokenizer'
@@ -846,7 +846,9 @@ def configure_optimizers(self):
        if hasattr(self._cfg.optim, 'sched'):
            sched_config = self._cfg.optim.sched
            self._scheduler = prepare_lr_scheduler(
-               optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
+               optimizer=self._optimizer,
+               scheduler_config=sched_config,
+               train_dataloader=self._train_dl,
            )

        if getattr(self._cfg.optim, 'sched', None) is not None and self._scheduler is None:
@@ -397,6 +397,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

        self.inference_params = None

        # Reset learning rate params
        self.if_init_step = True
        self.reset_lr = self.cfg.get('reset_lr', False)
        self.reset_lr_steps = self.cfg.get('reset_lr_steps', False)
        if self.reset_lr and (not self.with_distributed_adam or not self.megatron_amp_O2):
            raise ValueError(
                'Learning rate reset feature is only supported with the distributed optimizer and megatron_amp_O2 for now.'
            )

        # default to false since this doesn't work with sequence parallelism currently
        self.use_loss_mask = self.cfg.get('use_loss_mask', False)

@@ -763,6 +772,20 @@ def training_step(self, dataloader_iter):
        if self.initialize_ub:
            self.initialize_ub_func()

        # Reset learning rate
        if self.if_init_step and self.reset_lr:
            num_groups = len(self._optimizer.param_groups)
            for group in range(num_groups):
                self._optimizer.param_groups[group]['lr'] = (
                    0.0 if self.cfg.optim.sched.warmup_steps > 0 else self.cfg.optim.lr
                )
            self._optimizer.param_groups[0]['reset_lr'] = {
                'num_steps': self.trainer.global_step,
                'reset_lr_steps': True if self.reset_lr_steps else False,
                'if_init_step': self.if_init_step,
            }
            self.if_init_step = False

        if self.rampup_batch_size:
            num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
            current_global_batch_size = num_microbatch_calculator.current_global_batch_size
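The block above runs once, on the first training step after resuming: every param group's LR is forced back to its starting point (0.0 when warmup is configured), and reset metadata is stashed on the optimizer for the scheduler to consume. A minimal sketch of that handshake, using plain PyTorch objects and hypothetical values in place of the NeMo optimizer wrapper:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=2e-4)

global_step_at_resume = 3  # hypothetical trainer.global_step when resuming

# warmup_steps > 0 in the config, so the LR restarts from zero and warms up again
optimizer.param_groups[0]['lr'] = 0.0
optimizer.param_groups[0]['reset_lr'] = {
    'num_steps': global_step_at_resume,  # consumed later by the scheduler's get_lr()
    'reset_lr_steps': True,
    'if_init_step': True,
}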
35 changes: 30 additions & 5 deletions nemo/core/optim/lr_scheduler.py
@@ -97,7 +97,14 @@ class SquareRootConstantPolicy(_LRScheduler):
"""

def __init__(
self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
self,
optimizer,
*,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (
constant_steps is not None and constant_ratio is not None
@@ -114,7 +121,7 @@ def __init__(self):
        else:
            self.constant_steps = 0

-       self.constant_lr = 1 / (constant_steps ** 0.5)
+       self.constant_lr = 1 / (constant_steps**0.5)
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

@@ -280,6 +287,16 @@ def get_lr(self):

        step = self.last_epoch

        # Reset learning rate
        if 'reset_lr' in self.optimizer.param_groups[0].keys():
            reset_lr = self.optimizer.param_groups[0]['reset_lr']
            num_steps = reset_lr['num_steps']
            step -= num_steps
            if reset_lr['if_init_step'] and reset_lr['reset_lr_steps']:
                self.decay_steps -= num_steps
                self.max_steps -= num_steps
                self.optimizer.param_groups[0]['reset_lr']['if_init_step'] = False

        # Warmup steps
        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return self._get_warmup_lr(step)
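
A worked example of the step shift performed above, with hypothetical numbers matching the CI test (warmup_steps=2):

last_epoch = 4   # scheduler step counter, counting all steps since the very first run
num_steps = 3    # steps completed before the reset, read from the 'reset_lr' dict
max_steps = 6
decay_steps = 6

step = last_epoch - num_steps  # 1: the schedule behaves as if on its first step
max_steps -= num_steps         # 3 (applied once, only when reset_lr_steps=True)
decay_steps -= num_steps       # 3
# With warmup_steps=2, step=1 falls inside warmup again, so the LR ramps up from 0.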
@@ -364,7 +381,7 @@ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):

def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr):
    # hold_steps = total number of steps to hold the LR, not the warmup + hold steps.
-   T_warmup_decay = max(1, warmup_steps ** decay_rate)
+   T_warmup_decay = max(1, warmup_steps**decay_rate)
    T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
    lr = (initial_lr * T_warmup_decay) / T_hold_decay
    lr = max(lr, min_lr)
@@ -453,7 +470,15 @@ def _get_linear_warmup_with_cosine_annealing_lr(self, step):

class NoamAnnealing(_LRScheduler):
    def __init__(
-       self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
+       self,
+       optimizer,
+       *,
+       d_model,
+       warmup_steps=None,
+       warmup_ratio=None,
+       max_steps=None,
+       min_lr=0.0,
+       last_epoch=-1,
    ):
        self._normalize = d_model ** (-0.5)
        assert not (
@@ -593,7 +618,7 @@ def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs)
        super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)

    def _get_lr(self, step):
-       return [1 / (step ** 0.5) for _ in self.base_lrs]
+       return [1 / (step**0.5) for _ in self.base_lrs]


class PolynomialDecayAnnealing(WarmupPolicy):
