Sequence Parallel for LoRA (#8369)
* support lora + sequence parallel

Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more comments

Signed-off-by: Chen Cui <chcui@nvidia.com>

* add lora SP CI test

Signed-off-by: Chen Cui <chcui@nvidia.com>

* support lora for all linear modules as in #7988

Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cuichenx and pre-commit-ci[bot] authored Feb 23, 2024
1 parent f19b9a5 commit 5b38a7e
Showing 4 changed files with 72 additions and 3 deletions.
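For background on what the diffs below extend: a plain, single-device LoRA adapter is just a low-rank bottleneck added next to a frozen linear layer. A minimal sketch follows (illustrative only; the class and argument names are not from NeMo), before the commit makes the equivalent linear_in/linear_out pair work under tensor parallelism plus sequence parallelism, where activations are additionally split along the sequence dimension.

import torch
import torch.nn as nn

# Illustrative sketch only -- not part of this commit.
class LoRAAdapterSketch(nn.Module):
    def __init__(self, hidden_size: int, rank: int = 8, alpha: float = 32.0):
        super().__init__()
        self.linear_in = nn.Linear(hidden_size, rank, bias=False)   # down-projection
        self.linear_out = nn.Linear(rank, hidden_size, bias=False)  # up-projection
        nn.init.zeros_(self.linear_out.weight)                      # adapter starts as a no-op
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # low-rank update that gets added to the frozen base layer's output
        return self.linear_out(self.linear_in(x)) * self.scale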
36 changes: 36 additions & 0 deletions Jenkinsfile
@@ -4291,6 +4291,42 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf /home/TestData/nlp/lora_tuning_tp2"
}
}
stage('L2: Megatron GPT PEFT Lora TP=2 SP') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "rm -rf /home/TestData/nlp/lora_tuning_tp2_sp"
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.max_epochs=9999 \
trainer.max_steps=3 \
trainer.val_check_interval=3 \
++trainer.limit_val_batches=2 \
trainer.precision=16 \
exp_manager.exp_dir=/home/TestData/nlp/lora_tuning_tp2_sp \
model.pipeline_model_parallel_size=1 \
model.tensor_model_parallel_size=2 \
model.sequence_parallel=true \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/TP2/megatron_gpt_tp2.nemo \
model.peft.peft_scheme='lora' \
model.answer_only_loss=True \
model.micro_batch_size=1 \
model.global_batch_size=1 \
model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
model.data.train_ds.concat_sampling_probabilities=[1.0] \
model.data.train_ds.num_workers=0 \
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
model.data.validation_ds.names=[quarel]"
sh "rm -rf /home/TestData/nlp/lora_tuning_tp2_sp"
}
}
stage('L2: Megatron GPT Eval') {
when {
anyOf {
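The new 'L2: Megatron GPT PEFT Lora TP=2 SP' stage above is essentially the existing TP=2 LoRA test with sequence parallelism switched on. A minimal sketch of the distinguishing overrides, expressed with OmegaConf (illustrative; only the key names mirror the command above):

from omegaconf import OmegaConf

# Sketch of the overrides that make this the sequence-parallel variant.
overrides = OmegaConf.create(
    {
        "model": {
            "tensor_model_parallel_size": 2,   # SP only has an effect with TP > 1
            "pipeline_model_parallel_size": 1,
            "sequence_parallel": True,         # the new code path exercised by this stage
            "peft": {"peft_scheme": "lora"},
        }
    }
)
print(OmegaConf.to_yaml(overrides))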
@@ -755,7 +755,7 @@ def _append_sequence_parallel_module_grads(self, module, grads):
# (@adithyare) adapter training now extends MegatronGPTModel
# so we have to add this check here to ensure we do not
# perform all_reduce when grad is None.
# grad can be None when performing PEFT training.
if sequence_parallel_param and param.requires_grad:
if self.megatron_amp_O2:
grad = param.main_grad
@@ -775,7 +775,9 @@ def allreduce_sequence_parallel_gradients(self):
self._append_sequence_parallel_module_grads(module, grads)
else:
self._append_sequence_parallel_module_grads(self.model, grads)

if not grads:
# may be empty for PEFT training
return
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
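The early return added above matters because LoRA freezes almost every base-model weight, so the list of sequence-parallel gradients can legitimately be empty, and flattening an empty list would fail. A simplified, standalone sketch of the same pattern (function and attribute names assumed; torch.distributed must already be initialized):

import torch
import torch.distributed as dist

def allreduce_sequence_parallel_grads_sketch(params, tp_group):
    # Collect grads only from parameters marked sequence-parallel that are
    # actually being trained.  With LoRA most base weights are frozen, so
    # this list may be empty.
    grads = [
        p.grad.data
        for p in params
        if getattr(p, "sequence_parallel", False) and p.requires_grad and p.grad is not None
    ]
    if not grads:
        # nothing to reduce (e.g. PEFT training) -- skip the collective
        return
    coalesced = torch._utils._flatten_dense_tensors(grads)
    dist.all_reduce(coalesced, group=tp_group)
    for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)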
@@ -43,6 +43,10 @@
try:
from megatron.core import ModelParallelConfig
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region,
)

HAVE_MEGATRON_CORE = True

@@ -147,11 +151,13 @@ def __init__(
self.activation = activation_registry[activation]()
self.norm_position = norm_position
self.dim = dim

self.input_is_parallel = input_is_parallel
# megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
# in case this arg is not provided, use the dummy default config.
if model_parallel_config is None:
model_parallel_config = ModelParallelConfig()
self._sequence_parallel = model_parallel_config.sequence_parallel
model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer

if input_is_parallel:
self.linear_in = RowParallelLinear(
@@ -219,6 +225,9 @@ def __init__(
# Setup adapter strategy
self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())

# revert config change in case it is read elsewhere
model_parallel_config.sequence_parallel = self._sequence_parallel
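# Illustrative sketch (not part of this commit): the save/restore of
# model_parallel_config.sequence_parallel above could also be written as a
# context manager, which makes the temporary nature of the change explicit
# to any other module sharing the same config object.
from contextlib import contextmanager

@contextmanager
def sequence_parallel_disabled(model_parallel_config):
    saved = model_parallel_config.sequence_parallel
    model_parallel_config.sequence_parallel = False  # SP is irrelevant for the lora bottleneck linears
    try:
        yield model_parallel_config
    finally:
        model_parallel_config.sequence_parallel = saved  # restore for other readers of the shared config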

def _get_init_fn(self, init_method: str):
if init_method == 'xavier':
init_fn = init.xavier_normal_
@@ -240,10 +249,24 @@ def forward(self, x):

if self.norm_position == 'pre':
x = self.layer_norm(x)
if self._sequence_parallel and not self.input_is_parallel:
# for attention_qkv and linear_fc1
# the layernorm before lora is affected by sequence parallelism,
# so the seq dim needs to be gathered right before the lora linear layers.
# this function also handles the backward pass correctly
x = gather_from_sequence_parallel_region(x)

x, _ = self.linear_in(x) # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term.
x = self.activation(x)
x, _ = self.linear_out(x)

if self._sequence_parallel and self.input_is_parallel:
# for attention_dense and linear_fc2
# the layernorm after lora is affected by sequence parallelism,
# so the seq dim needs to be scattered right after the lora linear layers.
# this function also handles the backward pass correctly
x = scatter_to_sequence_parallel_region(x)

if self.norm_position == 'post':
x = self.layer_norm(x)

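To see what the new gather/scatter calls do to activation shapes, here is a single-process illustration. Assumptions: tensor-parallel size 2 and Megatron's [sequence, batch, hidden] layout; torch.chunk and torch.cat stand in for the real collectives over the tensor-parallel group.

import torch

TP = 2                       # tensor-parallel world size (assumed)
s, b, h = 8, 1, 16           # sequence length, batch, hidden size
full = torch.randn(s, b, h)  # activations outside the sequence-parallel region

# Inside the SP region each rank only holds a slice of the sequence dim.
shards = list(torch.chunk(full, TP, dim=0))          # each shard: [s/TP, b, h]

# gather_from_sequence_parallel_region: rebuild the full sequence so the
# LoRA linear_in (attention_qkv / linear_fc1 case) sees every token.
gathered = torch.cat(shards, dim=0)
assert gathered.shape == (s, b, h)

# scatter_to_sequence_parallel_region: split the adapter output back along
# the sequence dim (attention_dense / linear_fc2 case) so it matches what
# the surrounding sequence-parallel layers expect.
rank = 0
scattered = torch.chunk(gathered, TP, dim=0)[rank]
assert scattered.shape == (s // TP, b, h)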
8 changes: 8 additions & 0 deletions nemo/collections/nlp/parts/peft_config.py
@@ -16,6 +16,8 @@

from omegaconf import DictConfig

from nemo.utils import logging

try:
from nemo.collections.nlp.modules.common.megatron.adapters.mcore_mixins import (
MCoreGPTEmbeddingMixin,
@@ -148,6 +150,12 @@ def __init__(self, cfg):
)
name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)]
else:
logging.error(
f"Unrecognized target_module string: {module}.\n"
f"The possible options are: {list(PEFT_MODULE_MAP.values())}"
)
exit(1)

self.name_key_to_mcore_mixins = name_key_to_mcore_mixins
super().__init__(lora_cfg, name_key_to_cfg)
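The else branch added above turns an unrecognized target_modules entry into a hard error instead of silently skipping it. A standalone sketch of the same validation; the PEFT_MODULE_MAP contents shown here are assumed for illustration and may not match NeMo's actual map:

# Assumed, illustrative mapping -- consult nemo/collections/nlp/parts/peft_config.py
# for the authoritative PEFT_MODULE_MAP.
PEFT_MODULE_MAP = {
    "qkv_module": "attention_qkv",
    "dense_module": "attention_dense",
    "hto4h_module": "mlp_fc1",
    "4htoh_module": "mlp_fc2",
}

def validate_target_modules(target_modules):
    valid = list(PEFT_MODULE_MAP.values())
    for module in target_modules:
        if module not in valid:
            raise ValueError(
                f"Unrecognized target_module string: {module}.\n"
                f"The possible options are: {valid}"
            )

validate_target_modules(["attention_qkv", "attention_dense"])  # passes silently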
