@@ -335,6 +335,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -343,13 +349,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
 
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-
-        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
-        is_sharded_weight = getattr(param, "is_sharded_weight", False)
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow
-        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+            final_shape = list(loaded_weight.shape)
+            if output_dim is not None:
+                tp_size = get_tensor_model_parallel_world_size()
+                assert final_shape[output_dim] % tp_size == 0
+                final_shape[output_dim] = final_shape[output_dim] // tp_size
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
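
For context on what the hunks above do: the is_sharded_weight / use_bitsandbytes_4bit flags are now read before the GGUF branch so pre-sharded (bitsandbytes) weights skip the narrow further down, and a GGUF UninitializedParameter is materialized with only this rank's slice along output_dim. A minimal standalone sketch of that arithmetic follows, using made-up shapes and a hypothetical tp_size / tp_rank rather than vLLM's distributed helpers:

import torch

# Illustrative values only; in vLLM these come from the loaded GGUF tensor,
# the parameter's `output_dim` attribute, get_tensor_model_parallel_world_size()
# and get_tensor_model_parallel_rank().
full_shape = [4096, 11008]   # full (unsharded) weight shape
output_dim = 0               # dimension ColumnParallelLinear shards
tp_size, tp_rank = 4, 1      # hypothetical tensor-parallel config

# Shape each rank materializes for a GGUF UninitializedParameter
# (mirrors the added `final_shape` logic in the second hunk).
final_shape = list(full_shape)
if output_dim is not None:
    assert final_shape[output_dim] % tp_size == 0, "size must divide across ranks"
    final_shape[output_dim] //= tp_size
print(final_shape)  # [1024, 11008]

# The later narrow is skipped for pre-sharded weights (e.g. bitsandbytes
# 4-bit), which is why is_sharded_weight is computed before this point.
loaded_weight = torch.zeros(full_shape)
is_sharded_weight = False
if output_dim is not None and not is_sharded_weight:
    shard_size = final_shape[output_dim]
    start_idx = tp_rank * shard_size
    loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
print(tuple(loaded_weight.shape))  # (1024, 11008)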