File tree Expand file tree Collapse file tree 1 file changed +28
-10
lines changed Expand file tree Collapse file tree 1 file changed +28
-10
lines changed Original file line number Diff line number Diff line change 4242use_vocab_parallel = not use_1d_mesh
4343device = torch .device ("cuda" )
4444
45+ model_type = "8b"
46+
4547
def model_fn():
    """Construct the ``Transformer`` for the module-level ``model_type``.

    Two presets are recognised, ``"8b"`` and ``"70b"``. They differ only in
    ``dim``, ``n_layers``, ``n_heads`` and ``multiple_of``; every other
    argument is the same for both. ``vocab_size`` and ``seqlen`` are read
    from the enclosing module scope.

    Returns:
        A ``Transformer`` built from the selected ``TransformerModelArgs``.

    Raises:
        ValueError: if ``model_type`` is neither ``"8b"`` nor ``"70b"``.
    """
    # Pick the size-dependent values first; an unknown preset fails here,
    # before any of the shared settings are looked up.
    if model_type == "8b":
        width, depth, heads, ffn_multiple = 4096, 32, 32, 1024
    elif model_type == "70b":
        width, depth, heads, ffn_multiple = 8192, 80, 64, 4096
    else:
        raise ValueError(f"{model_type} not available")

    # Everything below is identical for both presets.
    config = TransformerModelArgs(
        dim=width,
        n_layers=depth,
        n_heads=heads,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=ffn_multiple,
        rope_theta=500000,
        vocab_size=vocab_size,
        max_seq_len=seqlen,
    )
    return Transformer(config)
5977
You can’t perform that action at this time.
0 commit comments