From a9d58f4d0882c87912911b2fc91fd96ae79aca72 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Thu, 18 Apr 2024 14:07:19 -0400
Subject: [PATCH] PP support in LoRA merge script (#8934)

* initial commit

Signed-off-by: Chen Cui

* enable pp support for merge script and fix output precision

Signed-off-by: Chen Cui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove incomplete script for next release

Signed-off-by: Chen Cui

---------

Signed-off-by: Chen Cui
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala
Co-authored-by: Eric Harper
---
 .../merge_lora_weights/merge.py | 95 ++++++++++---------
 1 file changed, 52 insertions(+), 43 deletions(-)

diff --git a/scripts/nlp_language_modeling/merge_lora_weights/merge.py b/scripts/nlp_language_modeling/merge_lora_weights/merge.py
index ccdb433630da..14fe3db80690 100644
--- a/scripts/nlp_language_modeling/merge_lora_weights/merge.py
+++ b/scripts/nlp_language_modeling/merge_lora_weights/merge.py
@@ -14,35 +14,34 @@
 # limitations under the License.
 
 """
-Merge lora weights into a base GPT LM. Only PP=1 supported so far.
+Merge lora weights into a base GPT LM.
+Supports any TP and PP size the LoRA model was trained with; TP/PP does not need to be specified when running this script.
 
 Example usage:
 python scripts/nlp_language_modeling/merge_lora_weights/merge.py \
     trainer.accelerator=gpu \ (use 'cpu' if model cannot fit in memory)
-    tensor_model_parallel_size=<TP size of the LoRA checkpoint> \
-    pipeline_model_parallel_size=1 \
     gpt_model_file=<path to base model .nemo file> \
     lora_model_path=<path to LoRA .nemo file> \
     merged_model_path=<path to output .nemo file>
-
-TP of lora checkpoint can be found by visually examining the output of
-`tar -tvf /path/to/lora.nemo`
 """


 import os
 import tempfile
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import torch
 from omegaconf import OmegaConf, open_dict
 from pytorch_lightning.trainer.trainer import Trainer
+from scripts.nlp_language_modeling.merge_lora_weights.convert_lora_parallelism import replace_number_add_offset
 from torch.utils.data import DataLoader, Dataset
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
+from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
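
# Illustrative sketch (not part of the patch): the docstring above no longer asks for TP/PP
# because load_lora() now reads them from the LoRA checkpoint itself. If you still want to
# inspect a .nemo archive by hand (the removed docstring suggested `tar -tvf`), a small
# Python equivalent could look like the following; sharded weights show up as "mp_rank_XX/"
# (TP only) or "tp_rank_XX_pp_rank_XXX/" (TP + PP) members. The function name and the
# example path are hypothetical.
import tarfile

def list_nemo_members(nemo_path: str) -> None:
    """Print the members of a .nemo tar archive, e.g. to spot per-rank weight folders."""
    with tarfile.open(nemo_path) as archive:
        for member in archive.getnames():
            print(member)

# list_nemo_members("/path/to/lora.nemo")
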
@@ -70,21 +69,35 @@ def __getitem__(self, idx):
         return self.sentences[idx]
 
 
-def load_lora(lora_nemo, tp):
-    lora_state_dict = {}
+def load_lora(lora_nemo):
     with tempfile.TemporaryDirectory() as tmpdir:
         NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir)
-        # assert os.path.isdir(lora_extracted_dir), "requires the untar'ed the lora .nemo file"
-        for i in range(tp):
-            if tp == 1:
-                ckpt_file = f"{tmpdir}/model_weights.ckpt"
-            else:
-                ckpt_file = f"{tmpdir}/mp_rank_0{i}/model_weights.ckpt"
-
-            l = torch.load(ckpt_file, map_location=torch.device('cpu'))
-            lora_state_dict[i] = l
         config_file = f"{tmpdir}/model_config.yaml"
         lora_config = OmegaConf.load(config_file)
+        tp_size = lora_config.tensor_model_parallel_size
+        pp_size = lora_config.pipeline_model_parallel_size
+
+        lora_state_dict = [{}] * tp_size
+
+        for pp in range(pp_size):
+            for tp in range(tp_size):
+                if tp_size == 1:
+                    ckpt_file = f"{tmpdir}/model_weights.ckpt"
+                elif pp_size == 1:
+                    ckpt_file = f"{tmpdir}/mp_rank_{tp:02d}/model_weights.ckpt"
+                else:
+                    ckpt_file = f"{tmpdir}/tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt"
+
+                l = torch.load(ckpt_file, map_location=torch.device('cpu'))
+                if pp == 0:
+                    lora_state_dict[tp] = l
+                else:
+                    # calculate layer offset
+                    layer_offset = lora_config.num_layers // pp_size * pp
+                    for key, value in l.items():
+                        new_key = replace_number_add_offset(key, layer_offset)
+                        lora_state_dict[tp][new_key] = value
+
     return lora_state_dict, lora_config
 
 
@@ -97,16 +110,16 @@ def fix_for_O2(state_dict):
 
 
 def merge(
-    base_model_state_dict: Dict[str, Any], lora_state_dict: Dict[int, Any], tp: int, num_layers: int, mcore: bool,
+    base_model_state_dict: Dict[str, Any], lora_state_dicts: List[Dict], num_layers: int, mcore: bool,
 ):
     """
-    Iterate through all the self_attention.query_key_value projection feedforward weights in all the layers.
+    Iterate through all the feedforward weights in all the layers.
     Collect the corresponding lora weights for each layer and across tp ranks.
-    Computes the "full rank" weight from the two low-rank weights and add it to the self_attention.query_key_value weight.
+    Computes the "full rank" weight from the two low-rank weights and adds it to the feedforward weight.
     Args:
         base_model_state_dict: A state_dict for the base model for the current rank.
-        lora_state_dict: A complete set of weights for the lora model across all tp ranks. They key for this dict is an int tp rank.
-        tp: the tensor_model_parallel_size for the base_model (and the lora model)
+        lora_state_dicts: A complete set of weights for the lora model across all tp ranks.
+            The number of elements in this list is equal to the TP size.
         num_layers: the number of layers in the base_model to iterate over.
         curr_rank: current tp rank of the base model which is being merged with Lora.
         mcore: whether the model uses megatron core.
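
# Illustrative sketch (not part of the patch): load_lora() above shifts the layer index in
# every key of a PP > 1 shard by num_layers // pp_size * pp, because each pipeline stage
# numbers its layers locally from 0. replace_number_add_offset() is imported from
# convert_lora_parallelism; the helper below is only a hypothetical stand-in showing the idea.
import re

def add_layer_offset_to_key(key: str, offset: int) -> str:
    """Shift the 'layers.<n>.' index in a state-dict key by `offset` (illustration only)."""
    return re.sub(r"layers\.(\d+)\.", lambda m: f"layers.{int(m.group(1)) + offset}.", key)

# Example: with num_layers=32 and pp_size=2, local layer 0 on pipeline stage 1 is global layer 16:
# add_layer_offset_to_key(
#     "model.language_model.encoder.layers.0.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight", 16
# )
# -> "model.language_model.encoder.layers.16.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight"
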
@@ -139,13 +152,15 @@ def merge(
                 key_base = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["base_model_layer"]}'
                 key_lora_in = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_in"]}'
                 key_lora_out = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_out"]}'
-                if key_lora_in in lora_state_dict[0] and key_lora_out in lora_state_dict[0]:
-                    if key in ["attention_qkv", 'mlp_fc1']:
-                        wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
-                    else:
-                        wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=1).float()
-
-                    wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
+                if key_lora_in in lora_state_dicts[0] and key_lora_out in lora_state_dicts[0]:
+                    tp_dim_lora_in = 0 if key in ["attention_qkv", 'mlp_fc1'] else 1
+
+                    wt_lora_in = torch.cat(
+                        [state_dict[key_lora_in] for state_dict in lora_state_dicts], dim=tp_dim_lora_in
+                    ).float()
+                    wt_lora_out = torch.cat(
+                        [state_dict[key_lora_out] for state_dict in lora_state_dicts], dim=0
+                    ).float()
                     wt_base = base_model_state_dict[key_base]
                     wt_lora = wt_lora_out @ wt_lora_in
                     base_model_state_dict[key_base] = (wt_base.float() + wt_lora.to(wt_base.device)).type_as(wt_base)
@@ -157,8 +172,8 @@ def merge(
             key_lora_in = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight'
             key_lora_out = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight'
-            wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
-            wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
+            wt_lora_in = torch.cat([state_dict[key_lora_in] for state_dict in lora_state_dicts], dim=0).float()
+            wt_lora_out = torch.cat([state_dict[key_lora_out] for state_dict in lora_state_dicts], dim=0).float()
             wt_self_attn = base_model_state_dict[key_self_attn_kqv]
             wt_lora = wt_lora_out @ wt_lora_in
             base_model_state_dict[key_self_attn_kqv] = (
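
# Illustrative toy example (not part of the patch): the merge rule used above, assuming
# linear_in shards of shape (adapter_dim // tp, hidden) and linear_out shards of shape
# (hidden // tp, adapter_dim). Per-TP-rank low-rank factors are concatenated, their product
# forms the full-rank update, and the update is added to the base weight in float32 before
# casting back to the base dtype, mirroring the .float() / .type_as() calls in merge().
import torch

tp, hidden, adapter_dim = 2, 8, 4

wt_base = torch.randn(hidden, hidden, dtype=torch.bfloat16)
lora_in_shards = [torch.randn(adapter_dim // tp, hidden) for _ in range(tp)]
lora_out_shards = [torch.randn(hidden // tp, adapter_dim) for _ in range(tp)]

wt_lora_in = torch.cat(lora_in_shards, dim=0).float()    # (adapter_dim, hidden)
wt_lora_out = torch.cat(lora_out_shards, dim=0).float()  # (hidden, adapter_dim)
wt_lora = wt_lora_out @ wt_lora_in                        # full-rank update, (hidden, hidden)
merged = (wt_base.float() + wt_lora).type_as(wt_base)     # back to the base dtype
print(merged.shape, merged.dtype)                         # torch.Size([8, 8]) torch.bfloat16
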
@@ -233,23 +248,17 @@ def main(cfg) -> None:
         raise ValueError("need at least a nemo file or checkpoint dir")
 
     # load the lora weights on cpu for all ranks of the lora model
-    lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path, cfg.tensor_model_parallel_size)
+    lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path)
 
     # merge the lora weights with the base model, for this current rank.
-    merged_weights = merge(
-        model.state_dict(),
-        lora_weights,
-        tp=cfg.tensor_model_parallel_size,
-        num_layers=model.cfg.num_layers,
-        mcore=model.mcore_gpt,
-    )
+    merged_weights = merge(model.state_dict(), lora_weights, num_layers=model.cfg.num_layers, mcore=model.mcore_gpt)
 
     # load the merged_weights back into the base model, for this current rank.
     if model.cfg.megatron_amp_O2:
         merged_weights = fix_for_O2(merged_weights)
-    model.cfg.use_cpu_initialization = (
-        False  # set it back to False otherwise the merged model won't be loaded properly for futher tuning
-    )
+
+    # set use_cpu_initialization back to False otherwise the merged model won't be loaded properly for further tuning
+    model.cfg.use_cpu_initialization = False
     model.load_state_dict(merged_weights)
 
     if cfg.trainer.accelerator != 'cpu' and model.global_rank == 0:
@@ -269,7 +278,7 @@ def main(cfg) -> None:
     else:
         logging.info("Skipping inference validation of merged model since device is 'cpu'.")
 
-    model.save_to(cfg.merged_model_path)
+    model.to(dtype=torch_dtype_from_precision(trainer.precision)).save_to(cfg.merged_model_path)
     logging.info(f"saved merged model to {cfg.merged_model_path}")
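
# Illustrative sketch (not part of the patch): the new save path casts the merged model to
# the trainer precision via torch_dtype_from_precision (imported from nemo utils_funcs above),
# so the output .nemo is written in e.g. bf16 instead of the float32 buffers used while merging.
# A minimal, assumed mapping from Lightning-style precision flags to torch dtypes could look
# like this; the function name is hypothetical and the real helper may handle more cases.
import torch

def dtype_from_precision_sketch(precision) -> torch.dtype:
    """Map a precision flag such as 16, '16-mixed' or 'bf16-mixed' to a torch dtype."""
    if precision in (16, "16", "16-mixed"):
        return torch.float16
    if precision in ("bf16", "bf16-mixed"):
        return torch.bfloat16
    return torch.float32

# dtype_from_precision_sketch("bf16-mixed") -> torch.bfloat16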