Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing error when using sharded ddp #18435

Merged
merged 1 commit into from
Aug 3, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,9 +1344,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
reshard_after_forward=zero_3,
cpu_offload=cpu_offload,
).to(self.args.device)

# Distributed training using PyTorch FSDP
if self.fsdp is not None:
elif self.fsdp is not None:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
Expand Down Expand Up @@ -1394,7 +1393,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
)
if FSDPOption.OFFLOAD not in self.args.fsdp:
model.to(self.args.device)

elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
Expand Down