From 36fead398fb44ba8d478fe7a16db7ea8567ad8c2 Mon Sep 17 00:00:00 2001 From: Gustavo de Rosa Date: Thu, 1 Sep 2022 11:33:53 -0300 Subject: [PATCH] Adds timeout argument to training_args to avoid socket timeouts in DDP (#18562) * chore(training_args): Adds support for timeout argument. * fix(training_args): Passes make style through changes. * fix(training_args): Removes wrong docstring sentence. * fix(training_args): Fixes timeout not being JSON serializable. * fix(training_args_sm): Also updates timeout to timeout_delta. * fix(training_args): Fixes PR according to suggestions. --- .../sagemaker/training_args_sm.py | 4 +-- src/transformers/training_args.py | 27 ++++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/transformers/sagemaker/training_args_sm.py b/src/transformers/sagemaker/training_args_sm.py index 6be0deb1f479d0..e4a356a25b180f 100644 --- a/src/transformers/sagemaker/training_args_sm.py +++ b/src/transformers/sagemaker/training_args_sm.py @@ -92,7 +92,7 @@ def _setup_devices(self) -> "torch.device": elif is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.torch_smddp # noqa: F401 - torch.distributed.init_process_group(backend="smddp") + torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta) self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK")) device = torch.device("cuda", self.local_rank) self._n_gpu = 1 @@ -111,7 +111,7 @@ def _setup_devices(self) -> "torch.device": # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) device = torch.device("cuda", self.local_rank) self._n_gpu = 1 diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 8e6d91084c208a..646e9343571b99 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -18,6 +18,7 @@ import os import warnings from dataclasses import asdict, dataclass, field +from datetime import timedelta from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -481,6 +482,11 @@ class TrainingArguments: are also available. See the [Ray documentation]( https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for more options. + ddp_timeout (`int`, *optional*, defaults to 1800): + The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when + performing slow operations in distributed runnings. Please refer the [PyTorch documentation] + (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more + information. use_mps_device (`bool`, *optional*, defaults to `False`): Whether to use Apple Silicon chip based `mps` device. """ @@ -971,6 +977,12 @@ class TrainingArguments: ) }, ) + ddp_timeout: Optional[int] = field( + default=1800, + metadata={ + "help": "Overrides the default timeout for distributed training (value should be given in seconds)." + }, + ) def __post_init__(self): # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then). @@ -1291,6 +1303,13 @@ def eval_batch_size(self) -> int: eval_batch_size = per_device_batch_size * max(1, self.n_gpu) return eval_batch_size + @property + def ddp_timeout_delta(self) -> timedelta: + """ + The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable. + """ + return timedelta(seconds=self.ddp_timeout) + @cached_property @torch_required def _setup_devices(self) -> "torch.device": @@ -1358,7 +1377,9 @@ def _setup_devices(self) -> "torch.device": f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob" " performance." ) - torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size) + torch.distributed.init_process_group( + backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta + ) elif is_torch_tpu_available(): device = xm.xla_device() self._n_gpu = 0 @@ -1369,7 +1390,7 @@ def _setup_devices(self) -> "torch.device": elif is_sagemaker_dp_enabled(): import smdistributed.dataparallel.torch.torch_smddp # noqa: F401 - dist.init_process_group(backend="smddp") + dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta) self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK")) device = torch.device("cuda", self.local_rank) self._n_gpu = 1 @@ -1431,7 +1452,7 @@ def _setup_devices(self) -> "torch.device": # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) device = torch.device("cuda", self.local_rank) self._n_gpu = 1