From aa509e2d48e67e0e8814317c08ec08e559b6671c Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Wed, 3 Aug 2022 01:41:07 +0530
Subject: [PATCH] fixing error when using sharded ddp

---
 src/transformers/trainer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 59a1ca19a62c4f..37a21b0939c9e2 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1344,9 +1344,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
                     reshard_after_forward=zero_3,
                     cpu_offload=cpu_offload,
                 ).to(self.args.device)
-
         # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -1394,7 +1393,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
             )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
-
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]