Adds timeout argument to training_args to avoid socket timeouts in DDP #18562

Merged: 8 commits, Sep 1, 2022
4 changes: 2 additions & 2 deletions src/transformers/sagemaker/training_args_sm.py
@@ -92,7 +92,7 @@ def _setup_devices(self) -> "torch.device":
elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.torch_smddp # noqa: F401

torch.distributed.init_process_group(backend="smddp")
torch.distributed.init_process_group(backend="smddp", timeout=self.timeout_delta)
self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1
@@ -111,7 +111,7 @@ def _setup_devices(self) -> "torch.device":
# Here, we'll use torch.distributed.
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
torch.distributed.init_process_group(backend="nccl", timeout=self.timeout_delta)
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1

29 changes: 26 additions & 3 deletions src/transformers/training_args.py
@@ -18,6 +18,7 @@
import os
import warnings
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -963,6 +964,19 @@ class TrainingArguments:
)
},
)
timeout: Optional[int] = field(
Collaborator:
Suggested change
timeout: Optional[int] = field(
ddp_timeout: Optional[int] = field(

Let's make it clear this is a DDP argument.

default=1800,
metadata={
"help": (
"Overrides the default timeout defined by PyTorch and"
" introduces a way to prevent Socket Timeout when mapping large datasets."
" Expects timeout in seconds. Used for timeout argument in"
" torch.distributed.init_process_group calls. Please refer the PyTorch documentation"
" https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group"
" for more information."
Collaborator:
This is a bit too long here, and the docstring for this new argument is missing. I'd limit the help here to

Suggested change
"Overrides the default timeout defined by PyTorch and"
" introduces a way to prevent Socket Timeout when mapping large datasets."
" Expects timeout in seconds. Used for timeout argument in"
" torch.distributed.init_process_group calls. Please refer the PyTorch documentation"
" https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group"
" for more information."
"Overrides the default timeout for distributed training (value should be given in seconds).
```"
and you can add more info in the docstring.

Contributor (author):
Sounds perfect! Thanks @sgugger. I will push the changes in a couple of minutes.

)
},
)

def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
@@ -1283,6 +1297,13 @@ def eval_batch_size(self) -> int:
eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
return eval_batch_size

@property
def timeout_delta(self) -> timedelta:
"""
The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
"""
return timedelta(seconds=self.timeout)

@cached_property
@torch_required
def _setup_devices(self) -> "torch.device":
@@ -1335,7 +1356,9 @@ def _setup_devices(self) -> "torch.device":
"Looks like distributed multinode run but MASTER_ADDR env not set, "
"please try exporting rank 0's hostname as MASTER_ADDR"
)
torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
torch.distributed.init_process_group(
backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.timeout_delta
)
elif is_torch_tpu_available():
device = xm.xla_device()
self._n_gpu = 0
@@ -1346,7 +1369,7 @@
elif is_sagemaker_dp_enabled():
import smdistributed.dataparallel.torch.torch_smddp # noqa: F401

dist.init_process_group(backend="smddp")
dist.init_process_group(backend="smddp", timeout=self.timeout_delta)
self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1
@@ -1382,7 +1405,7 @@
# Here, we'll use torch.distributed.
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
torch.distributed.init_process_group(backend="nccl", timeout=self.timeout_delta)
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1

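For reference, a minimal usage sketch of the new option. It assumes the field ends up named `ddp_timeout`, as suggested in the review above; the argument name and the surrounding values are illustrative, not taken from the merged diff.

```python
from datetime import timedelta

from transformers import TrainingArguments

# Raise the distributed-init timeout to 2 hours so that a long dataset map()
# on rank 0 does not trip the default 30-minute timeout while the other ranks
# wait in torch.distributed collective/init calls.
args = TrainingArguments(
    output_dir="output",
    ddp_timeout=7200,  # seconds; field name per the review suggestion above (assumed)
)

# Internally the integer seconds are converted to the timedelta that
# torch.distributed.init_process_group(..., timeout=...) expects:
timeout = timedelta(seconds=args.ddp_timeout)
print(timeout)  # 2:00:00
```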