
Commit 1232662

Add config for Llama3 70B (#145)
1 parent f0c5c96 commit 1232662

File tree

1 file changed: +28 -10 lines changed


examples/example_llama3.py

Lines changed: 28 additions & 10 deletions
@@ -42,18 +42,36 @@
 use_vocab_parallel = not use_1d_mesh
 device = torch.device("cuda")
 
+model_type = "8b"
+
 
 def model_fn():
-    model_args = TransformerModelArgs(
-        dim=4096,
-        n_heads=32,
-        n_layers=32,
-        vocab_size=vocab_size,
-        max_seq_len=seqlen,
-        multiple_of=1024,
-        ffn_dim_multiplier=1.3,
-        n_kv_heads=8,
-    )
+    if model_type == "8b":
+        model_args = TransformerModelArgs(
+            dim=4096,
+            n_layers=32,
+            n_heads=32,
+            n_kv_heads=8,
+            ffn_dim_multiplier=1.3,
+            multiple_of=1024,
+            rope_theta=500000,
+            vocab_size=vocab_size,
+            max_seq_len=seqlen,
+        )
+    elif model_type == "70b":
+        model_args = TransformerModelArgs(
+            dim=8192,
+            n_layers=80,
+            n_heads=64,
+            n_kv_heads=8,
+            ffn_dim_multiplier=1.3,
+            multiple_of=4096,
+            rope_theta=500000,
+            vocab_size=vocab_size,
+            max_seq_len=seqlen,
+        )
+    else:
+        raise ValueError(f"{model_type} not available")
     m = Transformer(model_args)
     return m
 
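The two branches differ only in dim, n_layers, n_heads, and multiple_of; the grouped-query-attention setting (n_kv_heads=8), ffn_dim_multiplier=1.3, and rope_theta=500000 are shared between the 8B and 70B configs. As a rough illustration, the same selection logic can be written in table-driven form. This is a self-contained sketch, not the file's actual code: the TransformerModelArgs dataclass below is a hypothetical stand-in for the real class imported in examples/example_llama3.py, and the vocab_size and seqlen values are assumed, since the diff only references variables defined earlier in the script.

from dataclasses import dataclass

# Hypothetical stand-in for the TransformerModelArgs used in
# examples/example_llama3.py; fields mirror those set in the diff above.
@dataclass
class TransformerModelArgs:
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    ffn_dim_multiplier: float
    multiple_of: int
    rope_theta: int
    vocab_size: int
    max_seq_len: int

vocab_size = 128256  # assumed value; defined earlier in the real script
seqlen = 8192        # assumed value; defined earlier in the real script

# Only the fields that differ between the two model sizes.
_LLAMA3_SIZES = {
    "8b": dict(dim=4096, n_layers=32, n_heads=32, multiple_of=1024),
    "70b": dict(dim=8192, n_layers=80, n_heads=64, multiple_of=4096),
}

def make_model_args(model_type: str) -> TransformerModelArgs:
    if model_type not in _LLAMA3_SIZES:
        raise ValueError(f"{model_type} not available")
    return TransformerModelArgs(
        n_kv_heads=8,            # grouped-query attention, shared by both sizes
        ffn_dim_multiplier=1.3,
        rope_theta=500000,
        vocab_size=vocab_size,
        max_seq_len=seqlen,
        **_LLAMA3_SIZES[model_type],
    )

print(make_model_args("70b"))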
