convert_fsdp_to_hf.py
"""Convert hf model to checkpoint consummable by fsdp"""
import argparse
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import torch.distributed._shard.checkpoint as dist_cp
import torch
import transformers
from utils import add_padding_token
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--load_path", type=str, default="llama/7B_sharded")
parser.add_argument("--save_path", type=str, default="llama/7B_new_hf")
parser.add_argument("--added_tokens", type=int, default=1, help="Number of tokens added to the model that need to be ")
parser.add_argument("--config_path", type=str, default="llama/7B_hf")
args = parser.parse_args()
tokenizer = transformers.AutoTokenizer.from_pretrained(
args.config_path,
padding_side="left",
use_fast=False,
)
add_padding_token(tokenizer)
tokenizer.save_pretrained(args.save_path)
model_config = transformers.AutoConfig.from_pretrained(args.config_path)
model = LlamaForCausalLM(model_config).bfloat16()
if args.added_tokens > 0:
model.resize_token_embeddings(model.config.vocab_size + args.added_tokens)
state_dict = model.state_dict()
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(args.load_path),
no_dist=True
)
model.load_state_dict(state_dict)
model.save_pretrained(args.save_path)
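
A minimal round-trip check, as a sketch: the paths below simply mirror the script's argparse defaults (the "llama/7B_new_hf" directory is assumed to exist after the script has run, and is not otherwise part of the repository). Once the conversion finishes, the saved directory should load like any other Hugging Face checkpoint.

import torch
import transformers

# Load the converted checkpoint back to confirm it round-trips; the path is
# the script's default --save_path and is only an assumed example.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "llama/7B_new_hf", torch_dtype=torch.bfloat16
)
tokenizer = transformers.AutoTokenizer.from_pretrained("llama/7B_new_hf", use_fast=False)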