[ADD] support Distributed Data Parallel (#137)
## Title
Colossal AI-based Distributed Data Parallel with the Oslo interface

## Description
The purpose of this implementation is to enable DDP in Oslo. The reducer method is identical to that of Colossal AI, but adapted to fit Oslo's interface. To improve the user experience, we replaced `model.backward()` with `loss.backward()` and temporarily added `model.zero_grad()` to the code. Any feedback is welcome :)

If you don't call `model.zero_grad()`, you will get unexpected errors, because gradients keep accumulating in each parameter's `_saved_grad` buffer between steps.

test_data_parallel.py
```python
import os
import torch.multiprocessing as mp

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import nn
from torch import optim
import torch.distributed as dist

from oslo.torch.distributed.parallel_context import ParallelContext


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.zeros(20, 10).to(rank))
    labels = torch.zeros(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    print(outputs)
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(2)

```
test_oslo_data_parallel.py
```python
import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.nn.parallel.data_parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size):
    print(f"Running oslo DDP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, parallel_context)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    oslo.ready(ddp_model, parallel_context)
    optimizer.zero_grad()
    outputs = ddp_model(torch.zeros(20, 10).to(rank))
    labels = torch.zeros(20, 5).to(rank)
    loss = loss_fn(outputs, labels)
    ddp_model.backward(loss)
    optimizer.step()
    print(outputs)
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(2)

```


![image](https://user-images.githubusercontent.com/26476095/220687694-4236dcaa-ae66-4332-8159-3e206b04df49.png)


![image](https://user-images.githubusercontent.com/26476095/220687852-578584a5-db9a-4a90-ab3e-bdc779bb39a2.png)

PyTorch DDP
<img width="585" alt="ddp_before_backward"
src="https://user-images.githubusercontent.com/26476095/221404650-2525413c-ce86-44e9-bd53-897ac4077b4a.png">
<img width="577" alt="ddp_after_backward"
src="https://user-images.githubusercontent.com/26476095/221404654-ce1e2d45-9304-4d13-aa83-c5a5f8d06689.png">

Oslo DDP
<img width="610" alt="oslo_before_backward"
src="https://user-images.githubusercontent.com/26476095/221404663-e85a0462-6fd2-4a6d-85a3-7fdcf9a5e9a7.png">
<img width="576" alt="oslo_after_backward"
src="https://user-images.githubusercontent.com/26476095/221404668-8cdee44d-3d76-4d23-adc0-68983ea7b173.png">

Checking the model's parameters before and after backward shows that Oslo DDP works as expected.

After cleaning up the code:

![image](https://user-images.githubusercontent.com/26476095/222415778-3358b862-a8c4-416e-9bc1-338d915d5e79.png)
 
Oslo DDP

[oslo-ddp-time.log](https://github.com/EleutherAI/oslo/files/10887632/oslo-ddp-time.log)

Torch DDP

[torch-ddp-time.log](https://github.com/EleutherAI/oslo/files/10887634/torch-ddp-time.log)


## Linked Issues

- resolved #00

---------

Co-authored-by: dongsung kim <kidsung@ip-172-31-42-218.ec2.internal>
Co-authored-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: KKIEEK <ethan9867@gmail.com>
4 people authored Mar 10, 2023
1 parent dcad48e commit f129a90
Showing 7 changed files with 327 additions and 3 deletions.
5 changes: 5 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/__init__.py
```python
from oslo.torch.nn.parallel.data_parallel.data_parallel import (
    DistributedDataParallel,
)
from oslo.torch.nn.parallel.data_parallel.zero import *

__ALL__ = ["DistributedDataParallel", "ZeroRedundancyOptimizer"]
```
116 changes: 116 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/_reducer.py
```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup


class Bucket:
    def __init__(
        self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup
    ):
        self.buffer = torch.zeros(size, dtype=dtype, device=device)
        self.group = group
        self.offset = 0
        self.callbacks: List[Callable] = []

    def flush(self) -> None:
        """Flush content of the bucket."""
        if self.offset == 0:
            assert len(self.callbacks) == 0
            return
        # reduce-scatter bucket
        dist.all_reduce(self.buffer[: self.offset], group=self.group)

        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
        # reuse input bucket but allocate a fresh output shard
        self.offset = 0
        self.callbacks.clear()
        self.buffer = torch.zeros_like(self.buffer)

    def alloc(self) -> None:
        if self.buffer.storage().size() == 0:
            self.buffer.storage().resize_(self.buffer.numel())

    def free(self) -> None:
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        self.buffer.storage().resize_(0)

    def append(self, tensor: Tensor, callback_fn: Callable):
        tensor_size = tensor.numel()
        offset = self.offset
        self.buffer[offset : offset + tensor_size].copy_(tensor.flatten())
        self.offset += tensor_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape)
            self.callbacks.append(functools.partial(callback_fn, result_view))

    @property
    def avail_size(self) -> int:
        return self.buffer.size(0) - self.offset


class Reducer:
    def __init__(self, bucket_size_mb: int = 25):
        self.bucket_size_mb = bucket_size_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def all_reduce_async(
        self,
        tensor: Tensor,
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        bucket_size = self._get_bucket_size(tensor.element_size())

        if tensor.numel() >= bucket_size:
            dist.all_reduce(tensor, group=group)
            if callback_fn is not None:
                callback_fn(tensor)
            return

        bucket = self._get_bucket(tensor, group)
        if tensor.numel() > bucket.avail_size:
            # not enough space remaining in bucket, flush it now
            bucket.flush()
        bucket.append(tensor, callback_fn)

    @torch.no_grad()
    def flush(self) -> None:
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def free(self) -> None:
        for bucket in self.buckets.values():
            bucket.free()

    @functools.lru_cache()
    def _get_bucket_size(self, element_size: int) -> int:
        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
        return int(bucket_size)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            bucket_size = self._get_bucket_size(tensor.element_size())
            self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
        self.buckets[key].alloc()
        return self.buckets[key]
```
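To make the bucketing behaviour concrete, here is a minimal single-process sketch (an illustration only; it assumes a CPU-only `gloo` process group of world size 1, and `reduced.append` stands in for the gradient-saving callback the DDP wrapper registers). Small tensors are copied into one flat buffer per `(dtype, device, group)` and are only all-reduced when `flush()` is called, at which point every registered `callback_fn` receives a view of its reduced slice; tensors at least as large as the bucket bypass the buffer and are all-reduced immediately.

```python
# Minimal single-process sketch of Reducer bucketing (CPU, gloo, world size 1).
import os
import torch
import torch.distributed as dist

from oslo.torch.nn.parallel.data_parallel._reducer import Reducer

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

reducer = Reducer(bucket_size_mb=25)
reduced = []  # collects the reduced gradient views handed to the callback

# Each small gradient is copied into the shared flat buffer instead of being
# all-reduced on its own; the callback fires only when the bucket is flushed.
for _ in range(4):
    grad = torch.ones(1024)
    reducer.all_reduce_async(grad, group=dist.group.WORLD, callback_fn=reduced.append)

assert not reduced          # nothing has been communicated yet
reducer.flush()             # one all_reduce over the filled part of the bucket
assert len(reduced) == 4 and reduced[0].shape == (1024,)

dist.destroy_process_group()
```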
24 changes: 24 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/_utils.py
```python
from typing import Iterable

import torch


def is_ddp_ignored(p):
    return getattr(p, "_ddp_to_ignore", False)


def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
    """Sets parameters to be ignored by DDP.
    This method must be called before initializing DistributedDataParallel.

    Example:
        >>> params_to_ignore = []
        >>> for p in module.parameters():
        >>>     if should_ignore(p):
        >>>         params_to_ignore.append(p)
        >>> set_params_to_ignore(params_to_ignore)
        >>> module = DistributedDataParallel(module)

    Args:
        params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
    """
    for p in params_to_ignore:
        p._ddp_to_ignore = True
```
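A quick sketch of how these two helpers fit together (no process group required; `model.bias` is just an arbitrary parameter chosen for illustration): `set_params_to_ignore` only tags the tensors with `_ddp_to_ignore`, and the DDP wrapper consults `is_ddp_ignored` before registering its gradient hook on a parameter.

```python
import torch

from oslo.torch.nn.parallel.data_parallel._utils import (
    is_ddp_ignored,
    set_params_to_ignore,
)

model = torch.nn.Linear(10, 10)
# Keep the bias purely local: its gradient will never be all-reduced.
set_params_to_ignore([model.bias])

print(is_ddp_ignored(model.weight))  # False -> synchronized by the reducer
print(is_ddp_ignored(model.bias))    # True  -> skipped by the DDP wrapper
```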
179 changes: 179 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/data_parallel.py
```python
import copy
from functools import partial

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"

from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel.utils import (
    get_parallel_context,
    add_wrapper,
    OsloParallelWrapper,
)
from oslo.torch.nn.parallel.data_parallel._reducer import Reducer
from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored


def free_storage(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


class BackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, inputs):
        ctx.module = module
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx.module._pre_backward()
        # Enqueue a callback to flush the reducer.
        # This callback is triggered after all gradients' calculation is completed.
        Variable._execution_engine.queue_callback(ctx.module._post_backward)
        return (None,) + grad_outputs


def DistributedDataParallel(
    module: nn.Module,
    parallel_context: ParallelContext,
    bucket_cap_mb: int = 25,
    rebuild_bucket: bool = True,
):
    ddp = _DistributedDataParallel(
        module=module,
        parallel_context=parallel_context,
        bucket_cap_mb=bucket_cap_mb,
        rebuild_bucket=rebuild_bucket,
    )

    add_wrapper(
        module,
        mode=ParallelMode.DATA,
        wrapper=ddp,
        parallel_context=parallel_context,
    )
    return module


class _DistributedDataParallel(OsloParallelWrapper):
    """Distributed data parallel wrapper for Oslo.

    Example:
        >>> from oslo.torch.nn.parallel import DistributedDataParallel as DDP
        >>> model = torch.nn.Linear(20, 1)
        >>> model = DDP(model, parallel_context)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> oslo.ready(model, parallel_context)
        >>> model.zero_grad()
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> loss.backward()
        >>> optimizer.step()

    Args:
        module (nn.Module): PyTorch module object
        parallel_context (ParallelContext): process group object
    """

    def __init__(
        self,
        module: torch.nn.Module,
        parallel_context: ParallelContext = None,
        bucket_cap_mb: int = 25,
        rebuild_bucket: bool = True,
    ) -> None:
        super().__init__(parallelism_priority=99)
        self.module = module

        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        assert parallel_context
        self.parallel_context = get_parallel_context(module, parallel_context)
        self.dp_world_size = self.parallel_context.get_world_size(ParallelMode.DATA)

        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket

        for p in module.parameters():
            if is_ddp_ignored(p):
                continue
            if p.requires_grad:
                p.register_hook(partial(self.grad_handle, p))

    def parallelize(self):
        self._forward = copy.copy(self.module.forward)
        self.module.zero_grad = self.zero_grad

    def forward(self, *args, **kwargs):
        return BackwardFunction.apply(self, self._forward(*args, **kwargs))

    def _pre_backward(self):
        pass

    def _post_backward(self):
        with torch.cuda.stream(self.comm_stream):
            self.reducer.flush()
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        if self.rebuild_bucket:
            self.reducer.free()
        for p in self.module.parameters():
            if is_ddp_ignored(p):
                continue
            if p.grad.device.type != "cpu":
                p.grad = p._saved_grad

    def grad_handle(self, p, grad):
        if grad.device.type != "cpu":
            empty_grad = torch.empty_like(grad)
            free_storage(empty_grad)
            if self.dp_world_size > 1:
                grad = grad / self.dp_world_size
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    self.reducer.all_reduce_async(
                        grad,
                        group=self.parallel_context.get_group(ParallelMode.DATA),
                        callback_fn=partial(self._save_grad, p),
                    )
                grad.record_stream(self.comm_stream)
            else:
                _DistributedDataParallel._save_grad(p, grad)

            return empty_grad

        else:
            # To run on CPU, call model.to('cpu') after oslo.ready().
            dist.all_reduce(
                grad, group=self.parallel_context.get_cpu_group(ParallelMode.DATA)
            )
            return grad

    @staticmethod
    def _save_grad(p, grad):
        if hasattr(p, "_saved_grad"):
            p._saved_grad.add_(grad)
        else:
            p._saved_grad = grad

    def zero_grad(self, set_to_none: bool = False) -> None:
        super().zero_grad(set_to_none=True)
        for p in self.module.parameters():
            if getattr(p, "_saved_grad", None) is not None:
                if set_to_none:
                    p._saved_grad = None
                else:
                    if p._saved_grad.grad_fn is not None:
                        p._saved_grad.detach_()
                    else:
                        p._saved_grad.requires_grad_(False)
                    p._saved_grad.zero_()
```
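For readers unfamiliar with the two autograd mechanisms the wrapper is built on, the toy sketch below reproduces them in a single CPU process without any Oslo or distributed code (the hook and callback bodies are simplified stand-ins, not the real implementation): a per-parameter `register_hook` stashes each gradient into a `_saved_grad` side buffer, and `Variable._execution_engine.queue_callback`, called from a hook while backward is running, schedules a function that fires only after the whole backward pass has finished.

```python
# Toy reproduction of the hook + post-backward-callback pattern used above
# (single process, CPU only; not the actual Oslo implementation).
from functools import partial

import torch
from torch.autograd import Variable

model = torch.nn.Linear(4, 2)


def grad_handle(p, grad):
    # Simplified stand-in for _save_grad: accumulate into a side buffer.
    if hasattr(p, "_saved_grad"):
        p._saved_grad.add_(grad)
    else:
        p._saved_grad = grad.clone()
    return grad


def post_backward():
    # In the real wrapper this is where the Reducer is flushed and p.grad
    # is overwritten with the synchronized p._saved_grad.
    print("backward done:", [tuple(p._saved_grad.shape) for p in model.parameters()])


def queue_post_backward(grad):
    # Mirrors BackwardFunction.backward: queue a callback that runs only
    # after the entire backward pass has completed.
    Variable._execution_engine.queue_callback(post_backward)
    return grad


for p in model.parameters():
    p.register_hook(partial(grad_handle, p))

output = model(torch.randn(8, 4))
output.register_hook(queue_post_backward)  # fires during backward, like the wrapper's autograd.Function
output.sum().backward()
```

In `_DistributedDataParallel` the callback is queued by `BackwardFunction`, which `forward()` applies to the module's output, so a plain `loss.backward()` is enough to trigger the asynchronous all-reduce and the post-backward flush.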
2 changes: 1 addition & 1 deletion oslo/torch/nn/parallel/data_parallel/zero/__init__.py
```diff
@@ -2,4 +2,4 @@
     ZeroRedundancyOptimizer,
 )
 
-__all__ = ["ZeroRedundancyOptimizer"]
+__ALL__ = ["ZeroRedundancyOptimizer"]
```

```diff
@@ -2,4 +2,4 @@
     ZeroRedundancyOptimizer,
 )
 
-__all__ = ["ZeroRedundancyOptimizer"]
+__ALL__ = ["ZeroRedundancyOptimizer"]
```

```diff
@@ -3,4 +3,4 @@
 from .parameter_store import ParameterStore
 from .tensor_store import TensorBucket
 
-_all__ = ["BucketStore", "GradientStore", "ParameterStore", "TensorBucket"]
+__ALL__ = ["BucketStore", "GradientStore", "ParameterStore", "TensorBucket"]
```
