
Commit

Merge branch 'fy/hf2megatron' of ssh://git.sankuai.com/~fengyu05/megatron-deepspeed into fy/hf2megatron

fengyu05 committed Sep 14, 2023
2 parents e9191fb + 1d09a68 commit 95dec64
Showing 4 changed files with 373 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tools/convert_checkpoint/README.md
@@ -76,3 +76,24 @@ cd /hf/transformers
python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \
/path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt
```

## HF Transformers to Megatron-DeepSpeed (currently only supports LLaMA)

To convert a LLaMA model from HF Transformers to Megatron-DeepSpeed, follow these two steps:

```bash
# 1. Convert LLaMA weights from HF Transformers to a Megatron checkpoint
python tools/convert_checkpoint/weights2megatron_llama.py \
--out=/path/to/Megatron-Deepspeed/checkpoint/ \
--cache-dir=/path/to/hf/transformers/llama_checkpoint

# 2. Convert the Megatron-DeepSpeed checkpoint into a distributed (tensor/pipeline-parallel) version
python3 tools/checkpoint_util.py \
--target-tensor-parallel-size 4 \
--target-pipeline-parallel-size 2 \
--load-dir /path/to/Megatron-Deepspeed/checkpoint/ \
--save-dir /path/to/Megatron-Deepspeed/distribute_checkpoint/ \
--model-type GPT
```
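
As a quick sanity check after step 1, the converted checkpoint can be loaded with plain `torch.load`. The snippet below is a minimal sketch (the path is a placeholder) based on the layout `weights2megatron_llama.py` writes out:

```python
# Minimal sanity check of the step-1 output; the checkpoint path below is a placeholder.
import torch

ckpt = torch.load(
    "/path/to/Megatron-Deepspeed/checkpoint/release/mp_rank_00/model_optim_rng.pt",
    map_location="cpu",
)
print(ckpt["checkpoint_version"])               # expected: 3.0
print(ckpt["args"].num_layers, ckpt["args"].hidden_size)
print(list(ckpt["model"]["language_model"]))    # embedding, encoder, lm_head
```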


111 changes: 111 additions & 0 deletions tools/convert_checkpoint/merge_llama.py
@@ -0,0 +1,111 @@
import os
import re
from pathlib import Path
from typing import Optional
from collections import OrderedDict

import torch
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, AutoTokenizer


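# Hidden (embedding) size for each LLaMA model scale.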
scale2emb = {
'7B': 4096,
'13B': 5120,
'30B': 6656,
'65B': 8192,
'70B': 8192,
}


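# Dimension along which each parameter type is sharded across Meta's consolidated.*.pth
# files (None means the tensor is replicated and can be taken from any shard).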
key_to_dim = {
"w1": 0,
"w2": -1,
"w3": 0,
"wo": -1,
"wq": 0,
"wk": 0,
"wv": 0,
"output": 0,
"tok_embeddings": -1,
"ffn_norm": None,
"attention_norm": None,
"norm": None,
"rope": None,
}


def init_merged_ckpt(pth_00, num_pth=8, emb_dim=8192):
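    """Allocate full-size tensors for every parameter in shard 0 (pth_00) and copy the
    shard-0 slice into them; the remaining shards are filled in by merge_meta_llama."""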
merged_ckpt = OrderedDict()
for parameter_name, parameter in pth_00.items():
short_name = parameter_name.split(".")[-2]
if key_to_dim[short_name] is None:
merged_ckpt[parameter_name] = parameter
del parameter
elif key_to_dim[short_name] == 0:
size = parameter.shape[0]
merged_param_shape = [ parameter.shape[0] * num_pth, parameter.shape[1] ]
merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
merged_ckpt[parameter_name][0 : size, :] = parameter
del parameter
elif key_to_dim[short_name] == -1:
size = parameter.shape[-1]
merged_param_shape = [ parameter.shape[0], parameter.shape[1] * num_pth]
merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
merged_ckpt[parameter_name][:, 0 : size] = parameter
del parameter
return merged_ckpt


def merge_meta_llama(size: int, root_dir: Path):
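    """Merge Meta's sharded consolidated.NN.pth files found in root_dir into a single
    state dict, combining each tensor along the dimension given by key_to_dim."""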
paths = sorted(path for path in root_dir.iterdir()
if re.match(r"^consolidated\.[0-9]+\.pth$", path.name))
if len(paths) == 1: # no sharded checkpoints, return everything
return torch.load(paths[0], map_location=torch.device("cpu"))

num_pth = len(paths)
for i, ckpt_path in enumerate(tqdm(paths, desc="Merging llama")):
llama_config = torch.load(ckpt_path, map_location=torch.device('cpu'))
if i == 0:
merged_ckpt = init_merged_ckpt(llama_config, num_pth=num_pth,
emb_dim=scale2emb[f"{size}B"])
else:
for parameter_name, parameter in llama_config.items():
short_name = parameter_name.split(".")[-2]
if key_to_dim[short_name] == 0:
size = parameter.shape[0]
merged_param_shape = [ parameter.shape[0] * num_pth, parameter.shape[1] ]
merged_ckpt[parameter_name][size * i : size * (i + 1), :] = parameter
del parameter
if key_to_dim[short_name] == -1:
size = parameter.shape[-1]
merged_param_shape = [ parameter.shape[0], parameter.shape[1] * num_pth]
merged_ckpt[parameter_name][:, size * i : size * (i + 1)] = parameter
del parameter
del llama_config
return merged_ckpt


def merge_hf_llama(cache_dir: Optional[Path] = None):
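    """Load a LLaMA checkpoint in HF Transformers format from cache_dir and rename its
    parameters to the Meta-style names used downstream (tok_embeddings, attention.wq,
    feed_forward.w1, ...). Returns the renamed state dict and the HF model config."""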
# assert version == 2, "Only llama v2 available using huggingface"
model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir, local_files_only=True, use_safetensors=False)
weights = model.state_dict()
weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight")
weights["norm.weight"] = weights.pop("model.norm.weight")
weights["output.weight"] = weights.pop("lm_head.weight")
for key in list(weights.keys()):
if rmatch := re.match(r"^model\.(layers\.[0-9]+\.)(.+)(\.weight)$", key):
new_key = {
"self_attn.q_proj": "attention.wq",
"self_attn.k_proj": "attention.wk",
"self_attn.v_proj": "attention.wv",
"self_attn.o_proj": "attention.wo",
"mlp.gate_proj": "feed_forward.w1",
"mlp.down_proj": "feed_forward.w2",
"mlp.up_proj": "feed_forward.w3",
"input_layernorm": "attention_norm",
"post_attention_layernorm": "ffn_norm"
}[rmatch.group(2)]
weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key)
return weights, model.config

81 changes: 81 additions & 0 deletions tools/convert_checkpoint/permute_qkv.py
@@ -0,0 +1,81 @@
import re
import sys
import os
import shutil
from pathlib import Path
from argparse import ArgumentParser

import torch
from tqdm.auto import tqdm


def permute_qkv(qkv_w: torch.Tensor, dim: int, n_heads: int,
n_heads_kv: int, revert: bool = False) -> torch.Tensor:
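    """Reorder the rows of each query/key head inside a fused QKV weight so the
    rotary-embedding layout matches the target convention (split halves vs.
    interleaved pairs); value heads are left untouched. revert=True applies the
    inverse permutation."""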

def permute(x):
if revert:
return x.view(head_dim//2, 2, dim).transpose(0, 1).reshape(head_dim, dim)
return x.view(2, head_dim//2, dim).transpose(0, 1).reshape(head_dim, dim)

head_dim = dim//n_heads
n_qs_per_kv = n_heads//n_heads_kv
n_groups = qkv_w.size(0)//head_dim//(n_qs_per_kv + 2)
groups = torch.chunk(qkv_w, n_groups, dim=0)
new = []
for group in groups:
*qs, k, v = torch.split(group, head_dim, dim=0)
assert len(qs) == n_qs_per_kv, f"{len(qs)}, {n_qs_per_kv}"
new += list(map(permute, qs)) + [permute(k), v]
return torch.cat(new, dim=0)


def update_checkpoint(input_dir: Path, output_dir: Path, overwrite_ok: bool = False):
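    """Rewrite every layer's fused query_key_value weight in a Megatron checkpoint with
    permute_qkv, saving the result under output_dir with the same directory layout."""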
# make sure megatron is importable
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))


# prepare output dir
if output_dir.exists():
if not overwrite_ok:
raise FileExistsError(f"Output directory {output_dir} already exists")
print(f"Removing {output_dir}")
shutil.rmtree(output_dir)
output_dir.mkdir(exist_ok=True)

    # determine release
with open(input_dir/"latest_checkpointed_iteration.txt") as f:
it = f.read()
print("Updating weights of iteration", it)
with open(output_dir/"latest_checkpointed_iteration.txt", "w+") as f:
f.write(it)
(output_dir/it).mkdir()

# convert weights
for fname in tqdm(list((input_dir/it).iterdir())):
checkpoint = torch.load(fname/"model_optim_rng.pt")
args = checkpoint["args"]
args = (args.hidden_size, args.num_attention_heads,
args.num_attention_heads_kv)
if "transformer" in checkpoint["model"]["language_model"]:
key = "transformer"
attn_key = "attention"
else:
key = "encoder"
attn_key = "self_attention"
states = checkpoint["model"]["language_model"][key]
for name, weight in states.items():
if re.match(rf"^layers\.[0-9]+\.{attn_key}\.query_key_value\.weight$", name):
states[name] = permute_qkv(weight, *args)
(output_dir/it/fname.stem).mkdir()
torch.save(checkpoint, output_dir/it/fname.stem/"model_optim_rng.pt")


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-dir", type=Path)
parser.add_argument("--output-dir", type=Path)
parser.add_argument("--overwrite-ok", action="store_true")
args = parser.parse_args()
update_checkpoint(args.input_dir, args.output_dir, args.overwrite_ok)
160 changes: 160 additions & 0 deletions tools/convert_checkpoint/weights2megatron_llama.py
@@ -0,0 +1,160 @@
import os
import sys
import shutil
from pathlib import Path
from typing import Optional
from argparse import ArgumentParser, Namespace

import torch
from tqdm.auto import trange
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import LlamaConfig

from permute_qkv import permute_qkv
from merge_llama import merge_hf_llama

def llama_to_megatron(weights: dict, llama_config: LlamaConfig = None) -> dict:
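    """Map Meta-style LLaMA weights onto the Megatron-DeepSpeed language_model layout:
    per-layer norms, attention.dense, a fused query_key_value weight, and a fused
    gate/up projection (mlp.dense_h_to_4h), plus embedding, final_layernorm and lm_head."""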
def permute(qkv_w):
return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)

def rearrange_qkv(wq, wk, wv):
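        # For every KV head, stack its query heads followed by the shared key and value
        # heads, then let permute_qkv fix the rotary layout of the fused weight.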
wq = torch.split(wq, n_hidden_per_head, dim=0)
wk = torch.split(wk, n_hidden_per_head, dim=0)
wv = torch.split(wv, n_hidden_per_head, dim=0)
assert len(wq) == n_heads
assert len(wk) == n_kv_heads
assert len(wv) == n_kv_heads
n_qs_per_kv = n_heads//n_kv_heads
w_qkv = []
for i in range(n_kv_heads):
w_qkv += [wq[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)]
w_qkv += [wk[i], wv[i]]
return permute(torch.concat(w_qkv))

# config
n_layer = llama_config.num_hidden_layers
hidden = llama_config.hidden_size
n_heads = llama_config.num_attention_heads
n_hidden_per_head = hidden//n_heads
n_kv_heads = llama_config.num_key_value_heads
# weights independent of layers
embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}}
transformer = {"final_layernorm.weight": weights["norm.weight"]}
lm_head = weights["output.weight"]
# get all the other weights
for layer in trange(n_layer, desc="Converting weights"):
prefix = f"layers.{layer}"
# identical weights
transformer[f"{prefix}.attention.dense.weight"] = \
weights[f"{prefix}.attention.wo.weight"]
transformer[f"{prefix}.post_attention_layernorm.weight"] = \
weights[f"{prefix}.ffn_norm.weight"]
transformer[f"{prefix}.input_layernorm.weight"] = \
weights[f"{prefix}.attention_norm.weight"]
transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \
weights[f"{prefix}.feed_forward.w2.weight"]
# concatenate up, gate mlp weights
transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([
weights[f"{prefix}.feed_forward.w3.weight"],
weights[f"{prefix}.feed_forward.w1.weight"]
])
# finally, qkv requires serious manipulation to get right
transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv(
weights[f"{prefix}.attention.wq.weight"],
weights[f"{prefix}.attention.wk.weight"],
weights[f"{prefix}.attention.wv.weight"]
)

# release references to original weights (free mem)
del weights[f"{prefix}.feed_forward.w3.weight"]
del weights[f"{prefix}.feed_forward.w1.weight"]
del weights[f"{prefix}.attention.wq.weight"]
del weights[f"{prefix}.attention.wk.weight"]
del weights[f"{prefix}.attention.wv.weight"]

return {"embedding": embedding, "encoder": transformer,
"lm_head": lm_head}

def main(out: Optional[Path] = None,
cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None):

    if megatron_path:
        print("Adding megatron to sys.path")
        sys.path.append(str(megatron_path))
    # get llama weights from the specified directory
print("Getting llama...")
hf_weights, llama_config = merge_hf_llama(cache_dir)

# convert state dict to be megatron-compatible
megatron_weights = llama_to_megatron(hf_weights, llama_config=llama_config)

# set args
# llama1, llama2
args = {"num_layers": llama_config.num_hidden_layers,
"hidden_size": llama_config.hidden_size,
"num_attention_heads": llama_config.num_attention_heads,
"ffn_hidden_size": llama_config.intermediate_size,
"num_key_value_heads": llama_config.num_key_value_heads,
"parallel_attn": False,
"make_vocab_size_divisible_by": 1,
"glu_activation": "swiglu",
"max_position_embeddings": llama_config.max_length, # should use max_length rather than max_position_embeddings, detail in https://github.com/lm-sys/FastChat/issues/2046#issuecomment-1645265800
"seq_length": llama_config.max_length,
"layernorm_epsilon": llama_config.rms_norm_eps,
# llama args
"padded_vocab_size": llama_config.vocab_size,
"tokenizer_type": "GPTSentencePieceTokenizer",
"no-query-key-layer-scaling": True,
"attention-dropout": 0,
"hidden-dropout": 0,
"use-rotary-position-embeddings": True,
"untie-embeddings-and-output-weights": True,
"swiglu": True,
"normalization": "rmsnorm",
"disable-bias-linear": True,
"add_position_embedding": False,
"add_bias_linear": False,
}
if llama_config.num_key_value_heads:
args.update({"num_attention_heads_kv": llama_config.num_key_value_heads})

args.update({
"tensor_model_parallel_size": 1,
"pipeline_model_parallel_size": 1,
"iteration": 0,
"bias_gelu_fusion": False,
"bias_droput_fusion": False,
})

# save converted weights in specified out
(out/"release"/"mp_rank_00").mkdir(parents=True)
with open(out/"latest_checkpointed_iteration.txt", "w+") as f:
f.write("release")
final_dict = {"iteration": 'release', "model": {"language_model": megatron_weights},
"checkpoint_version": 3.0, "args": Namespace(**args)}
torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt")
print("Saved weights in", out)

tokenizer = LlamaTokenizer.from_pretrained(
cache_dir, cache_dir=cache_dir, local_files_only=True,
)
token_path = out/"tokenizer.model"
vocab_file = tokenizer.vocab_file
shutil.copy(vocab_file, token_path)
print("Saved tokenizer.model in", token_path)
print("Done")

if __name__ == "__main__":
parser = ArgumentParser(description="Convert Huggingface llama weights to "
"megatron-compatible weights")
parser.add_argument("--out", type=Path,
help="Directory to store the megatron weights (as checkpoint)")
parser.add_argument("--cache-dir", type=Path,
help=("Directory to store the huggingface weights, or "
"in case of the llama model, where to look for "
"the consolidated.xx.pth"))
parser.add_argument("--megatron-path", type=Path, default=None,
help="Path where to find megatron code")
args = parser.parse_args()

main(args.out, args.cache_dir, args.megatron_path)
