PP support in LoRA merge script (#8934)
* initial commit

Signed-off-by: Chen Cui <chcui@nvidia.com>

* enable pp support for merge script and fix output precision

Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove incomplete script for next release

Signed-off-by: Chen Cui <chcui@nvidia.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala <adithya.r@gmail.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
4 people authored Apr 18, 2024
1 parent 9507f08 commit a9d58f4
Showing 1 changed file with 52 additions and 43 deletions.
scripts/nlp_language_modeling/merge_lora_weights/merge.py (52 additions & 43 deletions)
@@ -14,35 +14,34 @@
# limitations under the License.

"""
Merge lora weights into a base GPT LM. Only PP=1 supported so far.
Merge lora weights into a base GPT LM.
Supports any TP and PP the LoRA model was trained with; there is no need to specify TP/PP when running this script.
Example usage:
python scripts/nlp_language_modeling/merge_lora_weights/merge.py \
trainer.accelerator=gpu \ (use 'cpu' if model cannot fit in memory)
tensor_model_parallel_size=<TP of lora checkpoint> \
pipeline_model_parallel_size=1 \
gpt_model_file=<path to base model nemo file or extracted folder> \
lora_model_path=<path to megatron_gpt_peft_lora_tuning.nemo> \
merged_model_path=<output nemo file>
The TP of the lora checkpoint can be found by visually examining the output of
`tar -tvf /path/to/lora.nemo`
"""


import os
import tempfile
from typing import Any, Dict
from typing import Any, Dict, List

import torch
from omegaconf import OmegaConf, open_dict
from pytorch_lightning.trainer.trainer import Trainer
from scripts.nlp_language_modeling.merge_lora_weights.convert_lora_parallelism import replace_number_add_offset
from torch.utils.data import DataLoader, Dataset

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState
@@ -70,21 +69,35 @@ def __getitem__(self, idx):
return self.sentences[idx]


def load_lora(lora_nemo, tp):
lora_state_dict = {}
def load_lora(lora_nemo):
with tempfile.TemporaryDirectory() as tmpdir:
NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir)
# assert os.path.isdir(lora_extracted_dir), "requires the untar'ed lora .nemo file"
for i in range(tp):
if tp == 1:
ckpt_file = f"{tmpdir}/model_weights.ckpt"
else:
ckpt_file = f"{tmpdir}/mp_rank_0{i}/model_weights.ckpt"

l = torch.load(ckpt_file, map_location=torch.device('cpu'))
lora_state_dict[i] = l
config_file = f"{tmpdir}/model_config.yaml"
lora_config = OmegaConf.load(config_file)
tp_size = lora_config.tensor_model_parallel_size
pp_size = lora_config.pipeline_model_parallel_size

lora_state_dict = [{}] * tp_size

for pp in range(pp_size):
for tp in range(tp_size):
if tp_size == 1:
ckpt_file = f"{tmpdir}/model_weights.ckpt"
elif pp_size == 1:
ckpt_file = f"{tmpdir}/mp_rank_{tp:02d}/model_weights.ckpt"
else:
ckpt_file = f"{tmpdir}/tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt"

l = torch.load(ckpt_file, map_location=torch.device('cpu'))
if pp == 0:
lora_state_dict[tp] = l
else:
# calculate layer offset
layer_offset = lora_config.num_layers // pp_size * pp
for key, value in l.items():
new_key = replace_number_add_offset(key, layer_offset)
lora_state_dict[tp][new_key] = value

return lora_state_dict, lora_config

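For context, replace_number_add_offset (imported from convert_lora_parallelism above) renumbers layers so that each PP rank's shard maps onto a single global layer index. A minimal sketch of the assumed behavior, not the actual implementation:

```python
import re

def replace_number_add_offset_sketch(key: str, offset: int) -> str:
    # Shift the first integer in a state-dict key, e.g. with offset 8:
    # 'model.decoder.layers.3.mlp.weight' -> 'model.decoder.layers.11.mlp.weight'
    return re.sub(r'\d+', lambda m: str(int(m.group(0)) + offset), key, count=1)
```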

@@ -97,16 +110,16 @@ def fix_for_O2(state_dict):


def merge(
base_model_state_dict: Dict[str, Any], lora_state_dict: Dict[int, Any], tp: int, num_layers: int, mcore: bool,
base_model_state_dict: Dict[str, Any], lora_state_dicts: List[Dict], num_layers: int, mcore: bool,
):
"""
Iterate through all the self_attention.query_key_value projection feedforward weights in all the layers.
Iterate through all the feedforward weights in all the layers.
Collect the corresponding lora weights for each layer and across tp ranks.
Computes the "full rank" weight from the two low-rank weights and adds it to the self_attention.query_key_value weight.
Computes the "full rank" weight from the two low-rank weights and adds it to the feedforward weight.
Args:
base_model_state_dict: A state_dict for the base model for the current rank.
lora_state_dict: A complete set of weights for the lora model across all tp ranks. The key for this dict is the int tp rank.
tp: the tensor_model_parallel_size for the base_model (and the lora model)
lora_state_dicts: A complete set of weights for the lora model across all tp ranks.
The number of elements in this list is equal to the TP size.
num_layers: the number of layers in the base_model to iterate over.
mcore: whether the model uses megatron core.
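To make the "full rank" computation described above concrete, a toy example of the update rule with illustrative shapes (d_model and rank are made up for the example):

```python
import torch

d_model, rank = 8, 2
wt_base = torch.randn(d_model, d_model)     # frozen base weight
wt_lora_in = torch.randn(rank, d_model)     # low-rank "in" projection
wt_lora_out = torch.randn(d_model, rank)    # low-rank "out" projection
# The full-rank LoRA update is linear_out @ linear_in, added to the base weight.
wt_merged = wt_base + wt_lora_out @ wt_lora_in
assert wt_merged.shape == wt_base.shape
```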
@@ -139,13 +152,15 @@ def merge(
key_base = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["base_model_layer"]}'
key_lora_in = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_in"]}'
key_lora_out = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_out"]}'
if key_lora_in in lora_state_dict[0] and key_lora_out in lora_state_dict[0]:
if key in ["attention_qkv", 'mlp_fc1']:
wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
else:
wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=1).float()

wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
if key_lora_in in lora_state_dicts[0] and key_lora_out in lora_state_dicts[0]:
tp_dim_lora_in = 0 if key in ["attention_qkv", 'mlp_fc1'] else 1

wt_lora_in = torch.cat(
[state_dict[key_lora_in] for state_dict in lora_state_dicts], dim=tp_dim_lora_in
).float()
wt_lora_out = torch.cat(
[state_dict[key_lora_out] for state_dict in lora_state_dicts], dim=0
).float()
wt_base = base_model_state_dict[key_base]
wt_lora = wt_lora_out @ wt_lora_in
base_model_state_dict[key_base] = (wt_base.float() + wt_lora.to(wt_base.device)).type_as(wt_base)
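The tp_dim_lora_in switch above reflects how the adapter's input projection is assumed to be sharded across TP ranks: along dim 0 for attention_qkv and mlp_fc1, along dim 1 for the remaining projections, while linear_out shards always stack on dim 0. A shape-only sketch of the qkv/fc1 case (sizes invented for illustration):

```python
import torch

tp_size, shard_rank, d_in, d_out = 2, 4, 16, 64
# Each TP rank holds a (rank/tp, d_in) slice of linear_in
# and a (d_out/tp, rank) slice of linear_out.
in_shards = [torch.randn(shard_rank, d_in) for _ in range(tp_size)]
out_shards = [torch.randn(d_out // tp_size, tp_size * shard_rank) for _ in range(tp_size)]
wt_lora_in = torch.cat(in_shards, dim=0)    # (8, 16): full adapter rank
wt_lora_out = torch.cat(out_shards, dim=0)  # (64, 8)
wt_lora = wt_lora_out @ wt_lora_in          # (64, 16): full-rank update
```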
@@ -157,8 +172,8 @@
key_lora_in = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight'
key_lora_out = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight'

wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
wt_lora_in = torch.cat([state_dict[key_lora_in] for state_dict in lora_state_dicts], dim=0).float()
wt_lora_out = torch.cat([state_dict[key_lora_out] for state_dict in lora_state_dicts], dim=0).float()
wt_self_attn = base_model_state_dict[key_self_attn_kqv]
wt_lora = wt_lora_out @ wt_lora_in
base_model_state_dict[key_self_attn_kqv] = (
@@ -233,23 +248,17 @@ def main(cfg) -> None:
raise ValueError("need at least a nemo file or checkpoint dir")

# load the lora weights on cpu for all ranks of the lora model
lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path, cfg.tensor_model_parallel_size)
lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path)

# merge the lora weights with the base model, for this current rank.
merged_weights = merge(
model.state_dict(),
lora_weights,
tp=cfg.tensor_model_parallel_size,
num_layers=model.cfg.num_layers,
mcore=model.mcore_gpt,
)
merged_weights = merge(model.state_dict(), lora_weights, num_layers=model.cfg.num_layers, mcore=model.mcore_gpt)

# load the merged_weights back into the base model, for this current rank.
if model.cfg.megatron_amp_O2:
merged_weights = fix_for_O2(merged_weights)
model.cfg.use_cpu_initialization = (
False  # set it back to False, otherwise the merged model won't be loaded properly for further tuning
)

# set use_cpu_initialization back to False, otherwise the merged model won't be loaded properly for further tuning
model.cfg.use_cpu_initialization = False
model.load_state_dict(merged_weights)

if cfg.trainer.accelerator != 'cpu' and model.global_rank == 0:
Expand All @@ -269,7 +278,7 @@ def main(cfg) -> None:
else:
logging.info("Skipping inference validation of merged model since device is 'cpu'.")

model.save_to(cfg.merged_model_path)
model.to(dtype=torch_dtype_from_precision(trainer.precision)).save_to(cfg.merged_model_path)
logging.info(f"saved merged model to {cfg.merged_model_path}")
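The added cast on the save line keys the output dtype off the trainer's precision via torch_dtype_from_precision. An assumed approximation of that mapping, for intuition only (the real helper lives in nemo.collections.nlp.parts.utils_funcs):

```python
import torch

def torch_dtype_from_precision_sketch(precision) -> torch.dtype:
    # Assumed mapping from common Lightning precision flags to torch dtypes.
    if precision in ('bf16', 'bf16-mixed'):
        return torch.bfloat16
    if precision in (16, '16', '16-mixed'):
        return torch.float16
    return torch.float32
```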


