src/transformers/integrations/deepspeed.py (5 additions, 0 deletions)
@@ -464,6 +464,11 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
         model_parameters = None
     else:
         trainer.optimizer = None  # important for when deepspeed_init is used as re-init
+        tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 0)
+        if tp_size > 1:
+            import deepspeed
+
+            model = deepspeed.tp_model_init(model=model, tp_size=tp_size, dtype=hf_deepspeed_config.dtype())
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
         optimizer, lr_scheduler = deepspeed_optim_sched(
             trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
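
For context, the new branch only fires when the user's DeepSpeed config carries a tensor_parallel.autotp_size greater than 1, in which case the model is wrapped with deepspeed.tp_model_init before the optimizer and scheduler are built. Below is a minimal sketch of such a config, written as the Python dict one could pass to TrainingArguments(deepspeed=...); only the "tensor_parallel" block is what this diff actually reads, and the surrounding keys (batch size, ZeRO stage, bf16) are illustrative assumptions, not prescribed by this change.

    # Sketch of a DeepSpeed config dict that would take the new AutoTP path.
    # Assumption: auxiliary keys follow common DeepSpeed AutoTP training
    # examples; the diff itself only inspects "tensor_parallel.autotp_size".
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "tensor_parallel": {
            "autotp_size": 2,  # > 1 triggers deepspeed.tp_model_init above
        },
        "zero_optimization": {
            "stage": 1,  # assumption: AutoTP training pairs with ZeRO stage 1/2
        },
        "bf16": {"enabled": True},
    }

    # Hypothetical usage: hand the dict to the Trainer via TrainingArguments,
    # so deepspeed_init picks it up through hf_deepspeed_config.config.
    # args = TrainingArguments(output_dir="out", deepspeed=ds_config)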