Skip to content

Commit 581f912

Browse files
authored
Merge pull request #2 from huggingface/tp-training
Tensor Parallel Training support
2 parents 8bc3475 + eec466e commit 581f912

File tree

4 files changed

+12
-9
lines changed

4 files changed

+12
-9
lines changed

src/transformers/integrations/tensor_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
9191
device_map = tp_device
9292
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
9393
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
94-
return tp_device, device_map, device_mesh
94+
return tp_device, device_map, device_mesh, tp_size
9595

9696

9797
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, list[int]]) -> list[int]:

src/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4487,7 +4487,7 @@ def from_pretrained(
44874487
# `device_map` pointing to the correct device
44884488
if tp_plan is not None:
44894489
if device_mesh is None and tp_plan is not None:
4490-
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
4490+
tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
44914491
else:
44924492
# TODO: make device_mesh support multiple dimensions
44934493
if device_mesh.ndim > 1:

src/transformers/models/openai_moe/modeling_openai_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def forward(self, hidden_states):
137137
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
138138
router_logits = self.router(hidden_states)
139139
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
140-
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1)
140+
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
141141
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value).transpose(0, 1)
142142
routed_out = self.experts(hidden_states, router_indices, router_top_value)
143143
if self.training:

src/transformers/trainer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,9 +2294,7 @@ def _inner_training_loop(
22942294
else:
22952295
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
22962296

2297-
delay_optimizer_creation = (
2298-
is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
2299-
)
2297+
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
23002298

23012299
# Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
23022300
is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
@@ -2356,8 +2354,8 @@ def _inner_training_loop(
23562354
if self.use_apex:
23572355
model = self.accelerator.prepare(self.model)
23582356
else:
2359-
if delay_optimizer_creation:
2360-
model = self.accelerator.prepare(self.model)
2357+
if self.is_tp_enabled:
2358+
self.optimizer = self.accelerator.prepare(self.optimizer)
23612359
else:
23622360
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
23632361
else:
@@ -2604,7 +2602,12 @@ def _inner_training_loop(
26042602

26052603
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
26062604

2607-
self.optimizer.step()
2605+
context = contextlib.nullcontext
2606+
if self.is_tp_enabled:
2607+
from torch.distributed._tensor.experimental import implicit_replication
2608+
context = implicit_replication
2609+
with context():
2610+
self.optimizer.step()
26082611

26092612
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
26102613

0 commit comments

Comments (0)