Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PP support in LoRA merge script #8934

Merged
merged 7 commits into from
Apr 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions scripts/nlp_language_modeling/merge_lora_weights/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,34 @@
# limitations under the License.

"""
Merge lora weights into a base GPT LM. Only PP=1 supported so far.
Merge lora weights into a base GPT LM.
Supports any TP and PP the LoRA model was trained with; there is no need to specify TP/PP when running this script.

Example usage:
python scripts/nlp_language_modeling/merge_lora_weights/merge.py \
trainer.accelerator=gpu \ (use 'cpu' if model cannot fit in memory)
tensor_model_parallel_size=<TP of lora checkpoint> \
pipeline_model_parallel_size=1 \
gpt_model_file=<path to base model nemo file or extracted folder> \
lora_model_path=<path to megatron_gpt_peft_lora_tuning.nemo> \
merged_model_path=<output nemo file>

TP of lora checkpoint can be found by visually examining the output of
`tar -tvf /path/to/lora.nemo`
"""


import os
import tempfile
from typing import Any, Dict
from typing import Any, Dict, List

import torch
from omegaconf import OmegaConf, open_dict
from pytorch_lightning.trainer.trainer import Trainer
from scripts.nlp_language_modeling.merge_lora_weights.convert_lora_parallelism import replace_number_add_offset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing import?

from torch.utils.data import DataLoader, Dataset

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState
Expand Down Expand Up @@ -70,21 +69,35 @@ def __getitem__(self, idx):
return self.sentences[idx]


def load_lora(lora_nemo, tp):
lora_state_dict = {}
def load_lora(lora_nemo):
with tempfile.TemporaryDirectory() as tmpdir:
NLPSaveRestoreConnector._unpack_nemo_file(lora_nemo, tmpdir)
# assert os.path.isdir(lora_extracted_dir), "requires the untar'ed the lora .nemo file"
for i in range(tp):
if tp == 1:
ckpt_file = f"{tmpdir}/model_weights.ckpt"
else:
ckpt_file = f"{tmpdir}/mp_rank_0{i}/model_weights.ckpt"

l = torch.load(ckpt_file, map_location=torch.device('cpu'))
lora_state_dict[i] = l
config_file = f"{tmpdir}/model_config.yaml"
lora_config = OmegaConf.load(config_file)
tp_size = lora_config.tensor_model_parallel_size
pp_size = lora_config.pipeline_model_parallel_size

lora_state_dict = [{}] * tp_size

for pp in range(pp_size):
for tp in range(tp_size):
if tp_size == 1:
ckpt_file = f"{tmpdir}/model_weights.ckpt"
elif pp_size == 1:
ckpt_file = f"{tmpdir}/mp_rank_{tp:02d}/model_weights.ckpt"
else:
ckpt_file = f"{tmpdir}/tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt"

l = torch.load(ckpt_file, map_location=torch.device('cpu'))
if pp == 0:
lora_state_dict[tp] = l
else:
# calculate layer offset
layer_offset = lora_config.num_layers // pp_size * pp
for key, value in l.items():
new_key = replace_number_add_offset(key, layer_offset)
lora_state_dict[tp][new_key] = value

return lora_state_dict, lora_config


Expand All @@ -97,16 +110,16 @@ def fix_for_O2(state_dict):


def merge(
base_model_state_dict: Dict[str, Any], lora_state_dict: Dict[int, Any], tp: int, num_layers: int, mcore: bool,
base_model_state_dict: Dict[str, Any], lora_state_dicts: List[Dict], num_layers: int, mcore: bool,
):
"""
Iterate through all the self_attention.query_key_value projection feedforward weights in all the layers.
Iterate through all the feedforward weights in all the layers.
Collect the corresponding lora weights for each layer and across tp ranks.
Computes the "full rank" weight from the two low-rank weights and add it to the self_attention.query_key_value weight.
Compute the "full rank" weight from the two low-rank weights and add it to the feedforward weight.
Args:
base_model_state_dict: A state_dict for the base model for the current rank.
lora_state_dict: A complete set of weights for the lora model across all tp ranks. The key for this dict is an int tp rank.
tp: the tensor_model_parallel_size for the base_model (and the lora model)
lora_state_dicts: A complete set of weights for the lora model across all tp ranks.
The number of elements in this list is equal to the TP size.
num_layers: the number of layers in the base_model to iterate over.
curr_rank: current tp rank of the base model which is being merged with Lora.
mcore: whether the model uses megatron core.
Expand Down Expand Up @@ -139,13 +152,15 @@ def merge(
key_base = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["base_model_layer"]}'
key_lora_in = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_in"]}'
key_lora_out = f'model.decoder.layers.{nl}.{mcore_layer_to_lora[key]["lora_out"]}'
if key_lora_in in lora_state_dict[0] and key_lora_out in lora_state_dict[0]:
if key in ["attention_qkv", 'mlp_fc1']:
wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
else:
wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=1).float()

wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
if key_lora_in in lora_state_dicts[0] and key_lora_out in lora_state_dicts[0]:
tp_dim_lora_in = 0 if key in ["attention_qkv", 'mlp_fc1'] else 1

wt_lora_in = torch.cat(
[state_dict[key_lora_in] for state_dict in lora_state_dicts], dim=tp_dim_lora_in
).float()
wt_lora_out = torch.cat(
[state_dict[key_lora_out] for state_dict in lora_state_dicts], dim=0
).float()
wt_base = base_model_state_dict[key_base]
wt_lora = wt_lora_out @ wt_lora_in
base_model_state_dict[key_base] = (wt_base.float() + wt_lora.to(wt_base.device)).type_as(wt_base)
Expand All @@ -157,8 +172,8 @@ def merge(
key_lora_in = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight'
key_lora_out = f'model.language_model.encoder.layers.{nl}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight'

wt_lora_in = torch.cat([lora_state_dict[_tp][key_lora_in] for _tp in range(tp)], dim=0).float()
wt_lora_out = torch.cat([lora_state_dict[_tp][key_lora_out] for _tp in range(tp)], dim=0).float()
wt_lora_in = torch.cat([state_dict[key_lora_in] for state_dict in lora_state_dicts], dim=0).float()
wt_lora_out = torch.cat([state_dict[key_lora_out] for state_dict in lora_state_dicts], dim=0).float()
wt_self_attn = base_model_state_dict[key_self_attn_kqv]
wt_lora = wt_lora_out @ wt_lora_in
base_model_state_dict[key_self_attn_kqv] = (
Expand Down Expand Up @@ -233,23 +248,17 @@ def main(cfg) -> None:
raise ValueError("need at least a nemo file or checkpoint dir")

# load the lora weights on cpu for all ranks of the lora model
lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path, cfg.tensor_model_parallel_size)
lora_weights, lora_model_cfg = load_lora(cfg.lora_model_path)

# merge the lora weights with the base model, for this current rank.
merged_weights = merge(
model.state_dict(),
lora_weights,
tp=cfg.tensor_model_parallel_size,
num_layers=model.cfg.num_layers,
mcore=model.mcore_gpt,
)
merged_weights = merge(model.state_dict(), lora_weights, num_layers=model.cfg.num_layers, mcore=model.mcore_gpt)

# load the merged_weights back into the base model, for this current rank.
if model.cfg.megatron_amp_O2:
merged_weights = fix_for_O2(merged_weights)
model.cfg.use_cpu_initialization = (
False # set it back to False otherwise the merged model won't be loaded properly for further tuning
)

# set use_cpu_initialization back to False; otherwise the merged model won't be loaded properly for further tuning
model.cfg.use_cpu_initialization = False
model.load_state_dict(merged_weights)

if cfg.trainer.accelerator != 'cpu' and model.global_rank == 0:
Expand All @@ -269,7 +278,7 @@ def main(cfg) -> None:
else:
logging.info("Skipping inference validation of merged model since device is 'cpu'.")

model.save_to(cfg.merged_model_path)
model.to(dtype=torch_dtype_from_precision(trainer.precision)).save_to(cfg.merged_model_path)
logging.info(f"saved merged model to {cfg.merged_model_path}")


Expand Down
Loading