
Commit 9500978

Wesley Truong (wesleytruong) authored and committed
added model definition conversion for llama3
1 parent ad7f644 commit 9500978

File tree

5 files changed: +295 -15 lines changed


scripts/convert_from_hf.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from pathlib import Path

import torch
import torch.distributed.checkpoint as dcp
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.components.checkpoint import ModelWrapper
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.config_manager import ConfigManager
from torchtitan.tools.logging import init_logger  # used in __main__ below; assumed module path


@torch.inference_mode()
def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
    # initialize model to allocate memory for state dict
    train_spec = train_spec_module.get_train_spec(model_name)
    model_args = train_spec.model_args[model_flavor]

    config_manager = ConfigManager()
    config = config_manager.parse_args(
        [
            "--model.tokenizer-path",
            "./assets/tokenizer/Llama-3.1-8B",
        ]
    )
    tokenizer = build_hf_tokenizer(config)
    model_args.update_from_config(config, tokenizer)
    with torch.device("cpu"):
        model = train_spec.model_cls(model_args)
    model = ModelWrapper(model)

    sd_adapter = train_spec.state_dict_adapter
    assert (
        sd_adapter is not None
    ), "trying to convert checkpoint from HF to DCP safetensors format, but sd_adapter is not provided."
    # get state dict in tt format with allocated memory
    state_dict = model._get_state_dict()
    # convert empty state dict to hf format so that hf weights can be loaded into it
    hf_state_dict = sd_adapter.to_hf(state_dict, model_args)
    dcp.load(
        hf_state_dict,
        storage_reader=HuggingFaceStorageReader(path=input_dir),
    )
    # convert state dict format back hf->tt and save
    state_dict = sd_adapter.from_hf(hf_state_dict, model_args)
    dcp.save(
        state_dict,
        checkpoint_id=output_dir,
    )


if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser(description="Convert Llama weights to DCP format.")
    parser.add_argument(
        "input_dir", type=Path, help="Input directory with original Llama weights."
    )
    parser.add_argument("output_dir", type=Path, help="Output directory for DCP.")
    parser.add_argument("--model_name", type=str, nargs="?", default="llama3")
    parser.add_argument("--model_flavor", type=str, nargs="?", default="8B")
    args = parser.parse_args()

    convert_from_hf(
        args.input_dir,
        args.output_dir,
        args.model_name,
        args.model_flavor,
    )
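
The script works by allocating destination tensors in torchtitan (tt) format, renaming them to HF keys with to_hf so that dcp.load has correctly named targets for the safetensors files, and then converting back with from_hf before saving as DCP. A minimal, standalone sketch of that rename-load-rename pattern, with hypothetical keys and a copy_ standing in for dcp.load:

import torch

# hf -> tt key mapping for a toy one-entry "checkpoint" (hypothetical keys)
key_map = {"model.embed_tokens.weight": "tok_embeddings.weight"}

hf_sd = {"model.embed_tokens.weight": torch.empty(8, 4)}      # destination tensors under HF keys
hf_sd["model.embed_tokens.weight"].copy_(torch.ones(8, 4))    # stands in for dcp.load(..., HuggingFaceStorageReader)

tt_sd = {key_map[k]: v for k, v in hf_sd.items()}             # stands in for sd_adapter.from_hf(...)
assert torch.equal(tt_sd["tok_embeddings.weight"], torch.ones(8, 4))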

scripts/convert_to_hf.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from pathlib import Path

import torch
import torch.distributed.checkpoint as dcp
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.checkpoint import HuggingFaceStorageWriter
from torchtitan.components.checkpoint import ModelWrapper
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.config_manager import ConfigManager


@torch.inference_mode()
def convert_to_hf(input_dir, output_dir, model_name, model_flavor):
    # load model and model args so that we can get the state dict shape
    train_spec = train_spec_module.get_train_spec(model_name)
    model_args = train_spec.model_args[model_flavor]

    config_manager = ConfigManager()
    config = config_manager.parse_args(
        [
            "--model.tokenizer-path",
            "./assets/tokenizer/Llama-3.1-8B",
        ]
    )
    tokenizer = build_hf_tokenizer(config)
    model_args.update_from_config(config, tokenizer)
    with torch.device("cpu"):
        model = train_spec.model_cls(model_args)
    model = ModelWrapper(model)

    sd_adapter = train_spec.state_dict_adapter
    assert (
        sd_adapter is not None
    ), "trying to convert checkpoint from DCP to HF safetensors format, but sd_adapter is not provided."

    # allocate state dict memory with empty weights to load checkpoint
    state_dict = model._get_state_dict()
    dcp.load(
        state_dict,
        checkpoint_id=input_dir,
    )

    # convert state dict tt->hf
    hf_state_dict = sd_adapter.to_hf(state_dict, model_args)

    fqn_to_index_mapping = {}
    num_fqns_per_file = 30

    for i, key in enumerate(hf_state_dict.keys()):
        group_num = (i // num_fqns_per_file) + 1
        fqn_to_index_mapping[key] = group_num

    storage_writer = HuggingFaceStorageWriter(
        path=output_dir,
        save_distributed=True,
        fqn_to_index_mapping=fqn_to_index_mapping,
        enable_consolidation=True,
        thread_count_consolidation=5,
    )

    dcp.save(
        hf_state_dict,
        storage_writer=storage_writer,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to HF format.")
    parser.add_argument(
        "input_dir", type=Path, help="Input directory with the DCP checkpoint."
    )
    parser.add_argument("output_dir", type=Path, help="Output directory for HF format.")
    parser.add_argument("--model_name", type=str, nargs="?", default="llama3")
    parser.add_argument("--model_flavor", type=str, nargs="?", default="8B")
    args = parser.parse_args()

    convert_to_hf(
        args.input_dir,
        args.output_dir,
        args.model_name,
        args.model_flavor,
    )
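
The fqn_to_index_mapping loop above shards the output: each parameter FQN is assigned to a safetensors file, 30 FQNs per file, with 1-based file indices that are handed to HuggingFaceStorageWriter. A standalone sketch of the grouping with hypothetical key names:

# Mirrors the grouping loop in convert_to_hf above (keys are hypothetical).
num_fqns_per_file = 30
keys = [f"model.layers.{i}.self_attn.q_proj.weight" for i in range(65)]
fqn_to_index_mapping = {key: (i // num_fqns_per_file) + 1 for i, key in enumerate(keys)}

assert fqn_to_index_mapping[keys[0]] == 1    # keys 0-29  -> file 1
assert fqn_to_index_mapping[keys[30]] == 2   # keys 30-59 -> file 2
assert fqn_to_index_mapping[keys[64]] == 3   # keys 60-64 -> file 3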

torchtitan/components/checkpoint.py

Lines changed: 9 additions & 3 deletions
@@ -348,7 +348,9 @@ def dcp_save(
         checkpoint_save_id: str | None = None
         if to_hf:
             assert self.sd_adapter is not None
-            state_dict = self.sd_adapter.to_hf(state_dict)
+            state_dict = self.sd_adapter.to_hf(
+                state_dict, self.states["train_state"].model_args
+            )
 
             fqn_to_index_mapping = {}
             num_fqns_per_file = 30
@@ -415,14 +417,18 @@ def dcp_load(
             assert (
                 self.sd_adapter is not None
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
-            hf_state_dict = self.sd_adapter.to_hf(state_dict)
+            hf_state_dict = self.sd_adapter.to_hf(
+                state_dict, self.states["train_state"].model_args
+            )
 
             dcp.load(
                 hf_state_dict,
                 storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
             )
 
-            state_dict = self.sd_adapter.from_hf(hf_state_dict)
+            state_dict = self.sd_adapter.from_hf(
+                hf_state_dict, self.states["train_state"].model_args
+            )
             self.states[MODEL].load_state_dict(state_dict)
         else:
             dcp.load(state_dict, checkpoint_id=checkpoint_id)
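
With this change, both the checkpoint manager and the new conversion scripts call the adapter with the model args as a second argument, so the wq/wk permutation can derive n_heads and n_kv_heads. A sketch of the two-argument shape the call sites now assume (an illustrative stub, not the real StateDictAdapter base class):

from typing import Any


class StateDictAdapterSketch:
    """Illustrative stub of the interface relied on above."""

    @staticmethod
    def to_hf(state_dict: dict[str, Any], model_args) -> dict[str, Any]:
        raise NotImplementedError

    @staticmethod
    def from_hf(hf_state_dict: dict[str, Any], model_args) -> dict[str, Any]:
        raise NotImplementedError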

torchtitan/models/llama3/model/state_dict_adapter.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,129 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import re
 from typing import Any
 
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
 
+from .args import TransformerModelArgs
+
 
 class Llama3StateDictAdapter(StateDictAdapter):
+    from_hf_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    to_hf_map = {v: k for k, v in from_hf_map.items()}
+
+    # HuggingFace permutation function (exact copy from their conversion script)
     @staticmethod
-    def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
-        # TODO: implement this
-        return state_dict
+    def _permute(w, n_heads_arg, dim1=None, dim2=None):
+        if dim1 is None:
+            dim1 = w.shape[0]
+        if dim2 is None:
+            dim2 = w.shape[1]
+        return (
+            w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2)
+            .transpose(1, 2)
+            .reshape(dim1, dim2)
+            .clone()
+        )
+
+    @staticmethod
+    def _reverse_permute(w, n_heads_arg, dim1=None, dim2=None):
+        if dim1 is None:
+            dim1 = w.shape[0]
+        if dim2 is None:
+            dim2 = w.shape[1]
+        return (
+            w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2)
+            .transpose(1, 2)
+            .reshape(dim1, dim2)
+        )
 
     @staticmethod
-    def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
-        # TODO: implement this
+    def to_hf(
+        state_dict: dict[str, Any], model_args: TransformerModelArgs
+    ) -> dict[str, Any]:
+
+        n_heads = model_args.n_heads
+        n_kv_heads = (
+            model_args.n_kv_heads if model_args.n_kv_heads is not None else n_heads
+        )
+        dim = model_args.dim
+        head_dim = dim // n_heads
+        hf_state_dict = {}
+
+        for key, value in state_dict.items():
+            if "layers" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                layer_num = re.search(r"\d+", key).group(0)
+                new_key = Llama3StateDictAdapter.to_hf_map[abstract_key]
+                # We need to permute the weights in the wq and wk layers in order to account for
+                # the difference between the native Llama and HuggingFace RoPE implementations.
+                if abstract_key == "layers.{}.attention.wq.weight":
+                    value = Llama3StateDictAdapter._permute(value, n_heads)
+                if abstract_key == "layers.{}.attention.wk.weight":
+                    key_value_dim = head_dim * n_kv_heads
+                    value = Llama3StateDictAdapter._permute(
+                        value, n_kv_heads, key_value_dim, dim
+                    )
+
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = Llama3StateDictAdapter.to_hf_map[key]
+
+            hf_state_dict[new_key] = value
         return hf_state_dict
+
+    @staticmethod
+    def from_hf(
+        hf_state_dict: dict[str, Any], model_args: TransformerModelArgs
+    ) -> dict[str, Any]:
+        n_heads = model_args.n_heads
+        n_kv_heads = (
+            model_args.n_kv_heads if model_args.n_kv_heads is not None else n_heads
+        )
+        dim = model_args.dim
+        head_dim = dim // n_heads
+        state_dict = {}
+
+        for key, value in hf_state_dict.items():
+            if "layers" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                layer_num = re.search(r"\d+", key).group(0)
+                new_key = Llama3StateDictAdapter.from_hf_map[abstract_key]
+                print(f"{new_key} in layer {layer_num}")
+
+                # We need to permute the weights in the wq and wk layers in order to account for
+                # the difference between the native Llama and HuggingFace RoPE implementations.
+                if abstract_key == "model.layers.{}.self_attn.q_proj.weight":
+                    value = Llama3StateDictAdapter._reverse_permute(value, n_heads)
+                if abstract_key == "model.layers.{}.self_attn.k_proj.weight":
+                    key_value_dim = head_dim * n_kv_heads
+                    value = Llama3StateDictAdapter._reverse_permute(
+                        value, n_kv_heads, key_value_dim, dim
+                    )
+
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = Llama3StateDictAdapter.from_hf_map[key]
+
+            state_dict[new_key] = value
+        return state_dict
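
The adapter does two independent things per attention layer: it renames keys via an "abstract key" template, and it permutes wq/wk rows to reconcile the RoPE weight layout of the native Llama checkpoint with HuggingFace's. A self-contained sketch of both, with toy sizes and no torchtitan imports; the inlined view/transpose/reshape expressions mirror _permute and _reverse_permute above:

import re

import torch

# 1) Key renaming: template out the layer index, look up the abstract key, substitute back.
from_hf_map = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
}
key = "model.layers.7.self_attn.q_proj.weight"
abstract_key = re.sub(r"(\d+)", "{}", key, count=1)    # "model.layers.{}.self_attn.q_proj.weight"
layer_num = re.search(r"\d+", key).group(0)            # "7"
assert from_hf_map[abstract_key].format(layer_num) == "layers.7.attention.wq.weight"

# 2) RoPE permutation: the reverse permutation undoes the forward one, so tt -> hf -> tt is lossless.
n_heads, dim = 4, 16
w = torch.arange(dim * dim, dtype=torch.float32).reshape(dim, dim)
permuted = (
    w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim).clone()
)
restored = (
    permuted.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim)
)
assert torch.equal(restored, w)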

torchtitan/train.py

Lines changed: 7 additions & 7 deletions
@@ -137,15 +137,15 @@ def __init__(self, job_config: JobConfig):
         )
 
         # build model (using meta init)
-        model_args = self.train_spec.model_args[job_config.model.flavor]
+        self.model_args = self.train_spec.model_args[job_config.model.flavor]
         # set the model args from training job configs
-        model_args.update_from_config(job_config, tokenizer)
+        self.model_args.update_from_config(job_config, tokenizer)
 
         logger.info(
-            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
+            f"Building {self.train_spec.name} {job_config.model.flavor} with {self.model_args}"
         )
         with torch.device("meta"):
-            model = self.train_spec.model_cls(model_args)
+            model = self.train_spec.model_cls(self.model_args)
 
         # Build the collection of model converters. No-op if `model.converters` empty
         model_converters = build_model_converters(job_config, parallel_dims)
@@ -158,15 +158,15 @@ def __init__(self, job_config: JobConfig):
             else self.train_spec.build_metrics_processor_fn
         )
         self.metrics_processor = build_metrics_processor_fn(
-            job_config, parallel_dims, model_args
+            job_config, parallel_dims, self.model_args
         )
         color = self.metrics_processor.color
 
         # calculate model size and flops per token
         (
             model_param_count,
             self.metrics_processor.num_flops_per_token,
-        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)
+        ) = self.model_args.get_nparams_and_flops(model, job_config.training.seq_len)
 
         logger.info(
             f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
@@ -229,7 +229,7 @@ def __init__(self, job_config: JobConfig):
             parallel_dims,
             job_config,
             self.device,
-            model_args,
+            self.model_args,
             self.train_spec.parallelize_fn,
             self.loss_fn,
         )
