Skip to content

Commit 730d2a5

Browse files
authored
DeepSpeed tensor parallel+ZeRO (#36825)
add ds tp change
1 parent 1a37479 commit 730d2a5

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/transformers/integrations/deepspeed.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,11 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
464464
model_parameters = None
465465
else:
466466
trainer.optimizer = None # important for when deepspeed_init is used as re-init
467+
tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 0)
468+
if tp_size > 1:
469+
import deepspeed
470+
471+
model = deepspeed.tp_model_init(model=model, tp_size=tp_size, dtype=hf_deepspeed_config.dtype())
467472
model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
468473
optimizer, lr_scheduler = deepspeed_optim_sched(
469474
trainer, hf_deepspeed_config, args, num_training_steps, model_parameters

0 commit comments

Comments (0)