diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 59a1ca19a62c4f..37a21b0939c9e2 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1344,9 +1344,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
                 reshard_after_forward=zero_3,
                 cpu_offload=cpu_offload,
             ).to(self.args.device)
-        # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -1394,7 +1393,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                 )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
-
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
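
For context, a minimal standalone sketch of the branching behavior the patch changes (the names below are hypothetical stand-ins, not the actual Trainer internals): with a plain `if`, the PyTorch FSDP branch was evaluated even after an earlier branch had already wrapped the model, whereas `elif` keeps the wrapping strategies mutually exclusive.

# Hypothetical sketch, not the real _wrap_model implementation.
def _wrap_model_sketch(model, sharded_ddp=False, fsdp=False, sagemaker_dp=False):
    if sharded_ddp:
        model = ("FullyShardedDDP", model)   # fairscale wrapping, stubbed out
    elif fsdp:                               # was a plain `if fsdp:` before this patch
        model = ("TorchFSDP", model)         # torch.distributed.fsdp wrapping, stubbed out
    elif sagemaker_dp:
        model = ("SageMakerDDP", model)      # SageMaker data parallel, stubbed out
    return model


if __name__ == "__main__":
    # With `elif`, enabling both options no longer wraps the model twice.
    print(_wrap_model_sketch("my_model", sharded_ddp=True, fsdp=True))
    # -> ('FullyShardedDDP', 'my_model')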