Adds timeout argument to training_args to avoid socket timeouts in DDP #18562
@@ -18,6 +18,7 @@
 import os
 import warnings
 from dataclasses import asdict, dataclass, field
+from datetime import timedelta
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -963,6 +964,19 @@ class TrainingArguments:
             )
         },
     )
+    timeout: Optional[int] = field(
+        default=1800,
+        metadata={
+            "help": (
+                "Overrides the default timeout defined by PyTorch and"
+                " introduces a way to prevent Socket Timeout when mapping large datasets."
+                " Expects timeout in seconds. Used for timeout argument in"
+                " torch.distributed.init_process_group calls. Please refer the PyTorch documentation"
+                " https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group"
+                " for more information."
Reviewer (sgugger), commenting on the `help` text above: This is a bit too long here, and the docstring for this new argument is missing. I'd limit the help here to

Suggested change

Author: Sounds perfect! Thanks @sgugger. I will push the changes in a couple of minutes.
+            )
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
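With this change in place, a longer DDP timeout can be requested straight from `TrainingArguments`. Below is a minimal usage sketch, not code from the PR itself; the `output_dir` value is an arbitrary example, and `timeout` is the new argument from this diff, given in seconds:

```python
from transformers import TrainingArguments

# Sketch: raise the DDP timeout from PyTorch's 1800-second default to 2 hours
# so a long dataset .map() on rank 0 doesn't trip a socket timeout.
args = TrainingArguments(
    output_dir="out",  # arbitrary example value
    timeout=7200,      # new argument added by this PR, in seconds
)
```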
@@ -1283,6 +1297,13 @@ def eval_batch_size(self) -> int:
         eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return eval_batch_size

+    @property
+    def timeout_delta(self) -> timedelta:
+        """
+        The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+        """
+        return timedelta(seconds=self.timeout)
+
     @cached_property
     @torch_required
     def _setup_devices(self) -> "torch.device":
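The property above is a thin conversion layer: `torch.distributed.init_process_group` takes its `timeout` as a `datetime.timedelta`, not as a plain integer of seconds. A standalone illustration of the conversion:

```python
from datetime import timedelta

# init_process_group expects a timedelta, so an int of seconds must be wrapped.
timeout_seconds = 1800
print(timedelta(seconds=timeout_seconds))  # -> 0:30:00
```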
@@ -1335,7 +1356,9 @@ def _setup_devices(self) -> "torch.device":
                     "Looks like distributed multinode run but MASTER_ADDR env not set, "
                     "please try exporting rank 0's hostname as MASTER_ADDR"
                 )
-            torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
+            torch.distributed.init_process_group(
+                backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.timeout_delta
+            )
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
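Outside the Trainer, the multinode branch above boils down to the following hedged sketch. The backend name `"ccl"` is an assumption standing in for `self.xpu_backend`, and the `RANK`/`WORLD_SIZE` variables are assumed to be provided by the launcher:

```python
import os
from datetime import timedelta

import torch.distributed as dist

# Assumed launcher-provided environment (e.g. exported by mpirun/torchrun).
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# "ccl" stands in for self.xpu_backend here; the timeout now rides along.
dist.init_process_group(
    backend="ccl", rank=rank, world_size=world_size, timeout=timedelta(seconds=1800)
)
```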
@@ -1346,7 +1369,7 @@ def _setup_devices(self) -> "torch.device":
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

-            dist.init_process_group(backend="smddp")
+            dist.init_process_group(backend="smddp", timeout=self.timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -1382,7 +1405,7 @@ def _setup_devices(self) -> "torch.device":
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
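For the common single-node multi-GPU case, the patched NCCL path is equivalent to the standalone sketch below. It assumes the script is launched with `torchrun` or `torch.distributed.launch`, so the usual rendezvous environment variables (`MASTER_ADDR`, `RANK`, etc.) are already set:

```python
from datetime import timedelta

import torch.distributed as dist

# Guarded init, as in the diff: create the process group only once, and pass
# a user-chosen timeout instead of relying on the 30-minute default.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=7200))
```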
Reviewer: Let's make it clear this is a DDP argument.