1 file changed: +3 -10 lines changed
Original file line number | Diff line number | Diff line change
@@ -4122,16 +4122,9 @@ def save_pretrained(
41224122 for shard_file , tensors in filename_to_tensors :
41234123 shard = {}
41244124 for tensor in tensors :
4125- if _is_dtensor_available and getattr (self , "_device_mesh" , None ) is not None :
4126- plan = _get_parameter_tp_plan (tensor , self ._tp_plan )
4127- full_tensor = state_dict [tensor ]
4128- if isinstance (state_dict [tensor ], DTensor ):
4129- full_tensor = full_tensor .full_tensor ()
4130- elif plan is not None :
4131- shard_dim = - 1 if "rowwise" in plan else 0
4132- gather_list = [torch .empty_like (full_tensor ) for _ in range (self ._device_mesh .size ())]
4133- torch .distributed .all_gather (gather_list , full_tensor )
4134- full_tensor = torch .cat (gather_list , dim = shard_dim )
4125+ if _is_dtensor_available and isinstance (state_dict [tensor ], DTensor ):
4126+ full_tensor = state_dict [tensor ].full_tensor ()
4127+ # to get the correctly ordered tensor we need to repack if packed
41354128 if _get_parameter_tp_plan (tensor , self ._tp_plan ) in ("local_packed_rowwise" ,):
41364129 full_tensor = repack_weights (full_tensor , - 1 , self ._tp_size , 2 )
41374130 shard [tensor ] = full_tensor .contiguous () # only do contiguous after it's permuted correctly
You can’t perform that action at this time.
0 commit comments