
Commit ed511f2

Revert "remove dtensors, not explicit (#39840)"
This reverts commit 6dfd561.
1 parent d2303c7 commit ed511f2

File tree

1 file changed: +3 -10 lines changed


src/transformers/modeling_utils.py

Lines changed: 3 additions & 10 deletions
@@ -4122,16 +4122,9 @@ def save_pretrained(
         for shard_file, tensors in filename_to_tensors:
             shard = {}
             for tensor in tensors:
-                if _is_dtensor_available and getattr(self, "_device_mesh", None) is not None:
-                    plan = _get_parameter_tp_plan(tensor, self._tp_plan)
-                    full_tensor = state_dict[tensor]
-                    if isinstance(state_dict[tensor], DTensor):
-                        full_tensor = full_tensor.full_tensor()
-                    elif plan is not None:
-                        shard_dim = -1 if "rowwise" in plan else 0
-                        gather_list = [torch.empty_like(full_tensor) for _ in range(self._device_mesh.size())]
-                        torch.distributed.all_gather(gather_list, full_tensor)
-                        full_tensor = torch.cat(gather_list, dim=shard_dim)
+                if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
+                    full_tensor = state_dict[tensor].full_tensor()
+                    # to get the correctly ordered tensor we need to repack if packed
                     if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
                         full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
                     shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
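
The restored branch relies on DTensor to re-assemble each tensor-parallel shard before it is written, whereas the removed branch gathered shards by hand with torch.distributed.all_gather and a torch.cat along the shard dimension. Below is a minimal sketch, not the transformers implementation, contrasting the two approaches; it assumes a recent PyTorch (2.4+) where DTensor is exposed under torch.distributed.tensor, and the 2-process gloo setup, the script name, and the example weight are illustrative only.

# Minimal sketch, assuming PyTorch >= 2.4 (DTensor under torch.distributed.tensor).
# Run with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
world_size = dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

# Illustrative weight; row-shard it across ranks the way a tensor-parallel plan would.
weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
dtensor = distribute_tensor(weight, mesh, [Shard(0)])

# Restored path: the DTensor already knows its placements, so one call gathers the full weight.
full_from_dtensor = dtensor.full_tensor()

# Removed path: gather the local shards explicitly and concatenate along the shard dimension.
local = dtensor.to_local()
gather_list = [torch.empty_like(local) for _ in range(world_size)]
dist.all_gather(gather_list, local)
full_from_all_gather = torch.cat(gather_list, dim=0)

assert torch.equal(full_from_dtensor, weight)
assert torch.equal(full_from_all_gather, weight)
dist.destroy_process_group()

On evenly sharded weights both paths yield the same full tensor; full_tensor() simply keeps the gather logic inside DTensor instead of duplicating it in save_pretrained.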
