@@ -314,7 +314,8 @@ def forward(
314314 hidden_states = self .norm (hidden_states )
315315 return hidden_states
316316
317- def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
317+ def load_weights (self , weights : Iterable [tuple [str ,
318+ torch .Tensor ]]) -> set [str ]:
318319 stacked_params_mapping = [
319320 # (param_name, shard_name, shard_id)
320321 ("qkv_proj" , "q_proj" , "q" ),
@@ -325,6 +326,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
325326 ]
326327
327328 params_dict = dict (self .named_parameters (remove_duplicate = False ))
329+ loaded_params : set [str ] = set ()
328330 for name , loaded_weight in weights :
329331 if is_pp_missing_parameter (name , self ):
330332 continue
@@ -347,6 +349,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
347349 weight_loader = getattr (param , "weight_loader" ,
348350 default_weight_loader )
349351 weight_loader (param , loaded_weight )
352+ loaded_params .add (name )
353+ return loaded_params
350354
351355
352356class Olmo2ForCausalLM (nn .Module , SupportsPP ):
0 commit comments