add reset learning rate functionality #9372

Merged Jun 27, 2024 (30 commits)

Commits

911c39a  add reset_lr functionality (dimapihtar, Jun 3, 2024)
7802851  fix reset_lr logic (dimapihtar, Jun 3, 2024)
e4603f3  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 4, 2024)
b2f5eed  Apply isort and black reformatting (dimapihtar, Jun 4, 2024)
e6e9597  move reset_lr from optim section (dimapihtar, Jun 4, 2024)
5c4dd14  Apply isort and black reformatting (dimapihtar, Jun 4, 2024)
4668703  add reset_lr value to config (dimapihtar, Jun 4, 2024)
de6750a  set reset_lr False by default (dimapihtar, Jun 4, 2024)
b0b3e17  remove extra line (dimapihtar, Jun 4, 2024)
7fac9d3  add reset_lr test (dimapihtar, Jun 4, 2024)
0604dc4  add reset_lr test (dimapihtar, Jun 4, 2024)
5a2d4c6  remove extra quote (dimapihtar, Jun 5, 2024)
3cf2211  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 5, 2024)
47956e1  add ability to reset schedule's max_steps and decay_steps (dimapihtar, Jun 10, 2024)
6163909  Apply isort and black reformatting (dimapihtar, Jun 10, 2024)
df23cc9  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 10, 2024)
4119a1d  change scheduler's first step logic when using reset_lr (dimapihtar, Jun 10, 2024)
92e7cf8  revert config (dimapihtar, Jun 10, 2024)
d3f03f8  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 10, 2024)
5da92cd  fix reset_lr logic (dimapihtar, Jun 11, 2024)
badde31  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 11, 2024)
7cfd47a  Apply isort and black reformatting (dimapihtar, Jun 11, 2024)
067c264  revert config (dimapihtar, Jun 11, 2024)
43ccac7  revert config (dimapihtar, Jun 11, 2024)
0d91dcd  update reset_lr comments (dimapihtar, Jun 25, 2024)
c2d3765  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 25, 2024)
ce4200a  add use cases for reset_lr feature (dimapihtar, Jun 25, 2024)
e8e555b  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 25, 2024)
e6bec29  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 26, 2024)
1ebca00  Merge branch 'main' into dpykhtar/reset_lr (dimapihtar, Jun 26, 2024)

84 changes: 84 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -2630,6 +2630,89 @@ jobs:
# }
# }

L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2:
needs: [cicd-test-container-setup]
runs-on: self-hosted-azure
timeout-minutes: 10
container:
image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
options:
# --user 0:128
--device=/dev/nvidia0
--gpus all
--shm-size=8g
--env TRANSFORMERS_OFFLINE=0
--env HYDRA_FULL_ERROR=1
--volume /mnt/datadrive/TestData:/home/TestData
steps:
- name: Checkout repository
uses: actions/checkout@v4
- run: |
python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=3 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=3 \
trainer.precision=bf16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
model.tensor_model_parallel_size=2 \
model.megatron_amp_O2=True \
model.optim.name=distributed_fused_adam \
model.optim.lr=2e-4 \
model.optim.sched.warmup_steps=2 \
model.optim.sched.constant_steps=2 \
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.data.seq_length=128 \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \
model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings

python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=3 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=6 \
trainer.precision=bf16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \
exp_manager.resume_if_exists=True \
model.reset_lr=True \
model.tensor_model_parallel_size=2 \
model.megatron_amp_O2=True \
model.optim.name=distributed_fused_adam \
model.optim.lr=2e-4 \
model.optim.sched.warmup_steps=2 \
model.optim.sched.constant_steps=2 \
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.data.seq_length=128 \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \
model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings

rm -rf examples/nlp/language_modeling/gpt_pretrain_results
rm -rf examples/nlp/language_modeling/gpt_index_mappings
- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
if: "failure()"

L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2:
needs: [cicd-test-container-setup]
runs-on: self-hosted-azure
@@ -4296,6 +4379,7 @@ jobs:
- L2_BioMegatron_Bert_NER_Task
- L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_KERPLE_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_Pretraining_and_Resume_Training_PP2
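The two pretraining runs in the ResetLR job above encode the expected behavior: the first run trains for trainer.max_steps=3 and checkpoints, and the second resumes the same experiment with model.reset_lr=True and trainer.max_steps=6. A minimal sketch of the intended semantics, with hypothetical numbers (scheduler_step is an illustrative helper, not a NeMo function):

# Sketch only: the effective step the LR schedule sees after resuming.
def scheduler_step(global_step: int, completed_steps: int, reset_lr: bool) -> int:
    return global_step - completed_steps if reset_lr else global_step

# Without reset_lr, the resumed run continues at scheduler step 3 (warmup already spent).
assert scheduler_step(3, completed_steps=3, reset_lr=False) == 3
# With reset_lr, the first resumed step maps back to 0, so warmup replays from scratch.
assert scheduler_step(3, completed_steps=3, reset_lr=True) == 0
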
8 changes: 8 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -115,6 +115,14 @@ model:
seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595.
num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used.

## Reset learning rate schedule.
# 1. reset_lr=True, reset_lr_steps=False: pre-train an existing checkpoint "from scratch" on a different dataset.
# 2. reset_lr=True, reset_lr_steps=True: continue training from an existing checkpoint with the same configuration.
# The schedule's max_steps and decay_steps are recalculated as max_steps -= completed_steps and decay_steps -= completed_steps,
# where completed_steps is the number of steps already completed at the checkpoint.
# This lets the schedule reach min_lr by the end of training without changing trainer.max_steps.
reset_lr: False # Set to True to reset the learning rate to its initial value. Only supported with the distributed optimizer and megatron_amp_O2.
reset_lr_steps: False # Set to True to adjust the schedule's max_steps and decay_steps by subtracting the number of steps already completed at the checkpoint.

tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
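As a worked illustration of the recalculation described in the reset_lr comments above (hypothetical numbers; the subtraction itself happens inside the scheduler, as the lr_scheduler.py diff below shows):

# Hypothetical numbers for case 2 (reset_lr=True, reset_lr_steps=True).
completed_steps = 1000  # trainer.global_step restored from the checkpoint
max_steps = 5000        # schedule bounds from the original run
decay_steps = 5000

# The recalculation described above:
max_steps -= completed_steps    # 4000 schedule steps remain
decay_steps -= completed_steps  # decay still ends at trainer.max_steps,
                                # so the LR reaches min_lr by the end of training
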
@@ -846,7 +846,9 @@ def configure_optimizers(self):
if hasattr(self._cfg.optim, 'sched'):
sched_config = self._cfg.optim.sched
self._scheduler = prepare_lr_scheduler(
optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
optimizer=self._optimizer,
scheduler_config=sched_config,
train_dataloader=self._train_dl,
)

if getattr(self._cfg.optim, 'sched', None) is not None and self._scheduler is None:
@@ -397,6 +397,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

self.inference_params = None

# Reset learning rate params
self.if_init_step = True
self.reset_lr = self.cfg.get('reset_lr', False)
self.reset_lr_steps = self.cfg.get('reset_lr_steps', False)
if self.reset_lr and (not self.with_distributed_adam or not self.megatron_amp_O2):
raise ValueError(
'Learning rate reset feature is only supported with the distributed optimizer and megatron_amp_O2 for now.'
)

# default to false since this doesn't work with sequence parallelism currently
self.use_loss_mask = self.cfg.get('use_loss_mask', False)

@@ -763,6 +772,20 @@ def training_step(self, dataloader_iter):
if self.initialize_ub:
self.initialize_ub_func()

# Reset learning rate
if self.if_init_step and self.reset_lr:
num_groups = len(self._optimizer.param_groups)
for group in range(num_groups):
self._optimizer.param_groups[group]['lr'] = (
0.0 if self.cfg.optim.sched.warmup_steps > 0 else self.cfg.optim.lr
)
self._optimizer.param_groups[0]['reset_lr'] = {
'num_steps': self.trainer.global_step,
'reset_lr_steps': True if self.reset_lr_steps else False,
'if_init_step': self.if_init_step,
}
self.if_init_step = False

if self.rampup_batch_size:
num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
current_global_batch_size = num_microbatch_calculator.current_global_batch_size
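A minimal standalone sketch (plain PyTorch, not NeMo code) of the handshake above: training_step stashes one-shot reset metadata on param_groups[0], which the scheduler's get_lr consumes on its next call:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=2e-4)

# What the first training_step after resuming does (simplified):
optimizer.param_groups[0]['lr'] = 0.0  # warmup_steps > 0, so warmup restarts from zero
optimizer.param_groups[0]['reset_lr'] = {
    'num_steps': 3,          # trainer.global_step at the checkpoint (hypothetical)
    'reset_lr_steps': True,  # also shrink the schedule's max_steps / decay_steps
    'if_init_step': True,    # one-shot flag; get_lr clears it after first use
}

# What get_lr later reads back (simplified):
meta = optimizer.param_groups[0]['reset_lr']
step = 5 - meta['num_steps']  # e.g. global step 5 maps to scheduler step 2
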
35 changes: 30 additions & 5 deletions nemo/core/optim/lr_scheduler.py

@@ -97,7 +97,14 @@ class SquareRootConstantPolicy(_LRScheduler):
"""

def __init__(
self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
self,
optimizer,
*,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (
constant_steps is not None and constant_ratio is not None
@@ -114,7 +121,7 @@ def __init__(
else:
self.constant_steps = 0

self.constant_lr = 1 / (constant_steps ** 0.5)
self.constant_lr = 1 / (constant_steps**0.5)
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)

@@ -280,6 +287,16 @@ def get_lr(self):

step = self.last_epoch

# Reset learning rate
if 'reset_lr' in self.optimizer.param_groups[0].keys():
reset_lr = self.optimizer.param_groups[0]['reset_lr']
num_steps = reset_lr['num_steps']
step -= num_steps
if reset_lr['if_init_step'] and reset_lr['reset_lr_steps']:
self.decay_steps -= num_steps
self.max_steps -= num_steps
self.optimizer.param_groups[0]['reset_lr']['if_init_step'] = False

# Warmup steps
if self.warmup_steps > 0 and step <= self.warmup_steps:
return self._get_warmup_lr(step)
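Condensing the reset handling above into a pure function makes the arithmetic easier to trace (a sketch under the same assumptions as the diff, not library code):

def apply_reset(step, decay_steps, max_steps, reset_meta):
    # Shift the scheduler step by the steps already completed at the checkpoint;
    # on the first call only, also shrink the schedule bounds if requested.
    if reset_meta is not None:
        step -= reset_meta['num_steps']
        if reset_meta['if_init_step'] and reset_meta['reset_lr_steps']:
            decay_steps -= reset_meta['num_steps']
            max_steps -= reset_meta['num_steps']
            reset_meta['if_init_step'] = False
    return step, decay_steps, max_steps

meta = {'num_steps': 3, 'reset_lr_steps': True, 'if_init_step': True}
print(apply_reset(3, 6, 6, meta))  # (0, 3, 3): resume replays warmup from step 0
print(apply_reset(4, 3, 3, meta))  # (1, 3, 3): bounds are only adjusted once
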
@@ -364,7 +381,7 @@ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):

def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr):
# hold_steps = total number of steps to hold the LR, not the warmup + hold steps.
T_warmup_decay = max(1, warmup_steps ** decay_rate)
T_warmup_decay = max(1, warmup_steps**decay_rate)
T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
lr = (initial_lr * T_warmup_decay) / T_hold_decay
lr = max(lr, min_lr)
@@ -453,7 +470,15 @@ def _get_linear_warmup_with_cosine_annealing_lr(self, step):

class NoamAnnealing(_LRScheduler):
def __init__(
self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
self,
optimizer,
*,
d_model,
warmup_steps=None,
warmup_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
self._normalize = d_model ** (-0.5)
assert not (
@@ -593,7 +618,7 @@ def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs)
super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)

def _get_lr(self, step):
return [1 / (step ** 0.5) for _ in self.base_lrs]
return [1 / (step**0.5) for _ in self.base_lrs]


class PolynomialDecayAnnealing(WarmupPolicy):