     # Unused keys in load hooks (explicitly removed)
     r'layers.(\d+).attention.wqkv._extra_state': None,
     r'layers.(\d+).attention.wo._extra_state': None,
+
+    # MLP layer variant
+    r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight",  # might need to be fused for efficiency?
+    r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight",  # might need to be fused for efficiency?
+    # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
+    r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+    r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
+
     # Vision encoder mapping
     r"vision_embeddings.vision_encoder.conv1._linear": r"vision_model.patch_embedding.linear",
     r'vision_embeddings.vision_adapter.mlp.c_fc': r"vision_model.vision_adapter.mlp.fc1",
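
For orientation, the entries above map regex patterns over original checkpoint keys to replacement templates (None drops the key). A minimal sketch of how such a table is typically applied, with the helper name and lookup loop assumed rather than taken from this script:

import re

def convert_key(old_key, mapping):
    # Return the renamed key, None for keys that are dropped, or the key unchanged.
    for pattern, replacement in mapping.items():
        if re.search(pattern, old_key):
            return None if replacement is None else re.sub(pattern, replacement, old_key)
    return old_key

convert_key(
    "layers.3.feed_forward.w1.weight",
    {r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight"},
)
# -> 'language_model.model.layers.3.feed_forward.gate_proj.weight'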
@@ -142,6 +150,9 @@ def get_concat_dim(key):
142150 "experts.gate_proj" ,
143151 "experts.up_proj" ,
144152 "expert.down_proj" ,
153+ # "feed_forward.up_proj",
154+ # "feed_forward.gate_proj",
155+ "feed_forward.down_proj" ,
145156 "global_gate_stats" ,
146157 # vision dim1 sharded stuff
147158 "mlp.fc2.weight" , # covers all rowparallels across vis
@@ -166,6 +177,20 @@ def safe_load(filename):
     return shard


+# Unpack mlp projections - possibly to be removed when they are fused
+def preprocess_keys(state_dict):
+    new_state_dict = dict()
+    for key, value in state_dict.items():
+        if "mlp.fc1_weight" in key:
+            prefix = key.split("mlp.fc1_weight")[0]
+            w1, w3 = value.chunk(2, dim=0)
+            new_state_dict[prefix + "w1.weight"] = w1
+            new_state_dict[prefix + "w3.weight"] = w3
+        else:
+            new_state_dict[key] = value
+    return new_state_dict
+
+
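
A quick toy check of the split performed by preprocess_keys (shapes assumed: the fused fc1_weight stacks w1 on top of w3 along dim 0, which is exactly what chunk(2, dim=0) undoes):

import torch

fused_fc1 = torch.arange(12.0).reshape(6, 2)  # stand-in for a (2 * intermediate, hidden) weight
w1, w3 = fused_fc1.chunk(2, dim=0)
assert torch.equal(torch.cat([w1, w3], dim=0), fused_fc1)
print(w1.shape, w3.shape)  # torch.Size([3, 2]) torch.Size([3, 2])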
 def write_model(
     model_path,
     input_base_path,
@@ -194,14 +219,17 @@ def write_model(
     rms_norm_eps = params["norm_eps"]
     rope_theta = params["rope_theta"]

-    # some constants from original code
-    rope_scaling = {
-        "rope_type": "llama3",
-        "factor": 8.0,
-        "low_freq_factor": 1.0,
-        "high_freq_factor": 4.0,
-        "original_max_position_embeddings": 8192,
-    }
+    config_kwargs = {}
+    if params["use_scaled_rope"]:
+        # some constants from original code
+        rope_scaling = {
+            "rope_type": "llama3",
+            "factor": 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_max_position_embeddings": 8192,
+        }
+        config_kwargs.update(dict(rope_scaling=rope_scaling))
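
The point of config_kwargs is that rope_scaling only reaches the text config when the original checkpoint actually used scaled RoPE. A tiny illustration of the pattern (the params dict below is a toy stand-in for params.json):

params = {"use_scaled_rope": False}
config_kwargs = {}
if params["use_scaled_rope"]:
    config_kwargs["rope_scaling"] = {"rope_type": "llama3", "factor": 8.0}
print(config_kwargs)  # {} -> nothing extra is forwarded to Llama4TextConfig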

     # compute additional params for weight conversion
     num_heads_per_shard = num_heads // num_shards
@@ -211,9 +239,10 @@ def write_model(
     num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA

     num_experts = params["moe_args"]["num_experts"]
+    interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)

     bos_token_id = 200000
-    eos_token_id = [200001, 200002, 200003] if instruct else 200001
+    eos_token_id = [200001, 200002, 200003, 200008] if instruct else 200001
     pad_token_id = 200008

     text_config = Llama4TextConfig(
@@ -224,13 +253,16 @@ def write_model(
         rope_theta=rope_theta,
         num_hidden_layers=num_layers,
         intermediate_size=8192,
-        rope_scaling=rope_scaling,
+        intermediate_size_mlp=16384,
         num_local_experts=num_experts,
+        interleave_moe_layer_step=interleave_moe_layer_step,
+        use_qk_norm=params["use_qk_norm"],
         bos_token_id=bos_token_id,
         eos_token_id=eos_token_id,
         pad_token_id=pad_token_id,
         tie_word_embeddings=False,  # Constant set to False
         torch_dtype=torch_dtype,
+        **config_kwargs,
     )

236268
@@ -273,6 +305,7 @@ def write_model(
         safe_load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"))
         for i in tqdm(range(num_shards), desc="Loading shards", unit="shard")
     ]
+    loaded = [preprocess_keys(d) for d in loaded]

     all_keys_raw = list(loaded[0].keys())
     repeated_keys = []
@@ -354,7 +387,7 @@ def write_model(
         if gate_key == new_key:
             state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim)
         elif new_key == up_key:
-            if "shared" in new_key:
+            if "experts" not in new_key:
                 gate_proj = state_dict.pop(gate_key)
                 up_proj = torch.cat(current_parameter, dim=concat_dim)
                 state_dict[gate_key] = gate_proj
@@ -365,11 +398,11 @@ def write_model(
             else:
                 gate_proj = state_dict.pop(gate_key)
                 gate_proj = [
-                    gate_proj.reshape(16, -1, 8, 1024)[:, :, k, :].reshape(16, -1, 1024) for k in range(8)
+                    gate_proj.reshape(num_experts, -1, 8, 1024)[:, :, k, :].reshape(num_experts, -1, 1024) for k in range(8)
                 ]
                 gate_proj = torch.cat(gate_proj, dim=-1)

-                up_proj = [k.reshape(16, -1, 8, 1024).reshape(16, -1, 1024) for k in current_parameter]
+                up_proj = [k.reshape(num_experts, -1, 8, 1024).reshape(num_experts, -1, 1024) for k in current_parameter]
                 up_proj = torch.cat(up_proj, dim=-1)

                 gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
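
The reshape/cat gymnastics above are easy to get wrong; a toy shape-only check of the gate_proj path (all sizes here are assumed stand-ins, keeping the 8 column-parallel slices of width 1024 from the code):

import torch

num_experts, hidden = 4, 2 * 8 * 1024
gate_proj = torch.randn(num_experts, hidden)
slices = [
    gate_proj.reshape(num_experts, -1, 8, 1024)[:, :, k, :].reshape(num_experts, -1, 1024)
    for k in range(8)
]
out = torch.cat(slices, dim=-1)
print(out.shape)  # torch.Size([4, 2, 8192])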
@@ -432,10 +465,7 @@ def write_model(
     print("Loading the checkpoint in a Llama4 model.")
     state_dict.pop("")
     model.load_state_dict(state_dict, strict=True, assign=True)
-    print("Model reloaded successfully. Checking logits...")
-    # ipdb.set_trace()
-    # zero_out = model.forward(inputs_embeds=torch.zeros((1, 743, 4096)))
-    # ipdb.set_trace()
+    print("Model reloaded successfully.")
     print("Saving the model.")
     model.save_pretrained(model_path, safe_serialization=safe_serialization)
     del state_dict, model
@@ -448,8 +478,7 @@ def write_model(
     model = Llama4ForConditionalGeneration.from_pretrained(
         model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
     )
-    # ipdb.set_trace()
-    model.eval()
+
     model.generation_config.top_p = 0.9
     model.generation_config.temperature = 0.6
     print("Model reloaded successfully.")
@@ -458,7 +487,7 @@ def write_model(

     tokenizer = AutoTokenizer.from_pretrained(model_path)
     inputs = tokenizer(["Roses are red,"], return_tensors="pt").to(model.device)
-    out = model.generate(**inputs, max_new_tokens=10)
+    out = model.generate(**inputs, max_new_tokens=4)
     print(tokenizer.batch_decode(out))
     # generation config
     if instruct: