Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add customization point for init_process_group kwargs #228

Merged
merged 2 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/nlp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def collate_fn(examples):
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=len(train_dataloader) * num_epochs,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

# Now we train the model
Expand Down
23 changes: 16 additions & 7 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from packaging import version

from .data_loader import prepare_data_loader
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, KwargsHandler
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs, KwargsHandler
from .optimizer import AcceleratedOptimizer
from .state import AcceleratorState, DistributedType, is_deepspeed_available
from .utils import (
Expand Down Expand Up @@ -114,15 +114,10 @@ def __init__(
deepspeed_plugin, DeepSpeedPlugin
), "`deepspeed_plugin` must be a DeepSpeedPlugin object."

self.state = AcceleratorState(fp16=fp16, cpu=cpu, deepspeed_plugin=deepspeed_plugin, _from_accelerator=True)

self.device_placement = device_placement
self.split_batches = split_batches
self.dispatch_batches = dispatch_batches

# Kwargs handlers
self.ddp_handler = None
self.scaler_handler = None
self.init_handler = None
if kwargs_handlers is not None:
for handler in kwargs_handlers:
assert isinstance(handler, KwargsHandler), f"Unsupported kwargs handler passed: {handler}."
Expand All @@ -136,6 +131,20 @@ def __init__(
raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.")
else:
self.scaler_handler = handler
elif isinstance(handler, InitProcessGroupKwargs):
if self.init_handler is not None:
raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
else:
self.init_handler = handler

kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
self.state = AcceleratorState(
fp16=fp16, cpu=cpu, deepspeed_plugin=deepspeed_plugin, _from_accelerator=True, **kwargs
)

self.device_placement = device_placement
self.split_batches = split_batches
self.dispatch_batches = dispatch_batches

# Mixed precision attributes
self.scaler = None
Expand Down
15 changes: 15 additions & 0 deletions src/accelerate/kwargs_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import copy
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional


class KwargsHandler:
Expand Down Expand Up @@ -71,3 +73,16 @@ class GradScalerKwargs(KwargsHandler):
backoff_factor: float = 0.5
growth_interval: int = 2000
enabled: bool = True


@dataclass
class InitProcessGroupKwargs(KwargsHandler):
    """
    Kwargs handler used with :class:`~accelerate.Accelerator` to customize how the distributed process group is
    initialized. The fields correspond to the same-named arguments of `torch.distributed.init_process_group
    <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__; see that
    documentation for the meaning of each argument.

    Args:
        init_method (:obj:`str`, `optional`):
            Value for the ``init_method`` argument. Defaults to :obj:`None`.
        timeout (:obj:`datetime.timedelta`, `optional`):
            Value for the ``timeout`` argument. Defaults to 30 minutes (1800 seconds).
    """

    init_method: Optional[str] = None
    # 30 minutes == 1800 seconds; timedelta normalizes both spellings to the same value.
    timeout: timedelta = timedelta(minutes=30)
10 changes: 6 additions & 4 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ class AcceleratorState:

_shared_state = {}

def __init__(self, fp16: bool = None, cpu: bool = False, deepspeed_plugin=None, _from_accelerator: bool = False):
def __init__(
self, fp16: bool = None, cpu: bool = False, deepspeed_plugin=None, _from_accelerator: bool = False, **kwargs
):
self.__dict__ = self._shared_state
if not getattr(self, "initialized", False):
self.backend = None
Expand All @@ -161,7 +163,7 @@ def __init__(self, fp16: bool = None, cpu: bool = False, deepspeed_plugin=None,
), "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
self.distributed_type = DistributedType.DEEPSPEED
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
torch.distributed.init_process_group(backend="nccl", **kwargs)
self.backend = "nccl"
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
Expand All @@ -175,7 +177,7 @@ def __init__(self, fp16: bool = None, cpu: bool = False, deepspeed_plugin=None,
elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
self.distributed_type = DistributedType.MULTI_GPU
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
torch.distributed.init_process_group(backend="nccl", **kwargs)
self.backend = "nccl"
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
Expand Down Expand Up @@ -213,7 +215,7 @@ def __init__(self, fp16: bool = None, cpu: bool = False, deepspeed_plugin=None,
"please try exporting rank 0's hostname as MASTER_ADDR"
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend, rank=rank, world_size=size)
torch.distributed.init_process_group(backend, rank=rank, world_size=size, **kwargs)
self.backend = backend
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
Expand Down