Skip to content

Commit

Permalink
Enable ZeRO3 allgather for multiple dtypes (#4647)
Browse files Browse the repository at this point in the history
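This PR addresses an error reported in #4295.
When the parameters passed to a coalesced allgather have more than one data
type, DeepSpeed now groups them by data type and performs a separate allgather
for each group, returning a single handle that waits on all of them.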

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 8, 2023
1 parent 01af3e1 commit b8e1664
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 33 deletions.
97 changes: 65 additions & 32 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import functools
import itertools
from typing import List
from collections import defaultdict
import logging
import torch
from torch import Tensor
Expand All @@ -22,7 +23,7 @@

from deepspeed.utils import groups
import deepspeed
from ..utils import get_only_unique_item, see_memory_usage
from ..utils import see_memory_usage
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
Expand Down Expand Up @@ -656,6 +657,16 @@ def wait(self) -> None:
self.complete = True


class MultipleAllGatherHandles:
    """Bundle several all-gather handles behind a single ``wait()`` call."""

    def __init__(self, handles: List[AllGatherCoalescedHandle]):
        # Stored by reference; wait() visits the handles in this list's order.
        self.handles = handles

    def wait(self) -> None:
        """Call ``wait()`` on each stored handle, one after another."""
        for pending in self.handles:
            pending.wait()


class QuantizationInfo:
# a placeholder object to store all quant related vars used in handles
def __init__(self) -> None:
Expand Down Expand Up @@ -1050,6 +1061,42 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
param_list = [cls]
return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
    """Start one coalesced all-gather for ``params``, using a flat buffer of ``dtype``.

    Args:
        dtype: element type of the flat buffer that receives the gathered data.
        params: non-empty sequence of parameters; the first entry decides
            whether the secondary tensors are used.
        forward: when False and the first parameter has a secondary tensor,
            each parameter's ``ds_secondary_tensor`` is gathered instead of
            its ``ds_tensor``.
        world_size: number of equally sized partitions in the flat buffer.
        rank_in_group: index of the partition this rank fills with local data.
        ds_process_group: process group handed to ``_dist_allgather_fn``.

    Returns:
        An ``AllGatherCoalescedHandle`` wrapping the handle returned by
        ``_dist_allgather_fn`` together with the partition views.
    """
    device = get_accelerator().current_device_name()

    # Decided once from the first parameter and the pass direction.
    use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward

    if use_secondary_tensor:
        partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
    else:
        partition_sz = sum(p.ds_tensor.ds_numel for p in params)

    flat_tensor = torch.empty(partition_sz * world_size, dtype=dtype, device=device, requires_grad=False)

    # One view per rank into the flat buffer; no data is copied here.
    partitions: List[Parameter] = [
        flat_tensor.narrow(0, partition_sz * idx, partition_sz) for idx in range(world_size)
    ]

    # Concatenate this rank's local shards directly into its own slot.
    if use_secondary_tensor:
        local_shards = [p.ds_secondary_tensor.to(device) for p in params]
    else:
        local_shards = [p.ds_tensor.to(device) for p in params]
    instrument_w_nvtx(torch.cat)(local_shards, out=partitions[rank_in_group])

    # NOTE(review): the original carried a "Fix get_partition_dp_group(params[0])"
    # marker at this call; presumably the group was meant to be derived from
    # the parameters. Confirm before changing.
    handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)

    return AllGatherCoalescedHandle(
        allgather_handle=handle,
        params=params,
        partitions=partitions,
        world_size=world_size,
        use_secondary_tensor=use_secondary_tensor,
        forward=forward,
    )

@instrument_w_nvtx
def all_gather_coalesced(params: Iterable[Parameter],
forward: bool = True,
Expand Down Expand Up @@ -1146,42 +1193,28 @@ def all_gather_coalesced(params: Iterable[Parameter],
return AllGatherHandle(handle, param, quantization=quant_info)

else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)
if not quantize:
dtype_params = defaultdict(list)
for p in params:
dtype_params[p.ds_tensor.dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(
_all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group))

if params[0].ds_secondary_tensor is not None and not forward:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
return MultipleAllGatherHandles(handles)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=get_only_unique_item(p.ds_tensor.dtype
for p in params) if not quantize else torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)
if not quantize:
partitions: List[Parameter] = []
for i in range(world_size):
partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

if params[0].ds_secondary_tensor is not None and not forward:
use_secondary_tensor = True
instrument_w_nvtx(torch.cat)(
[p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
else:
instrument_w_nvtx(
torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
#Fix get_partition_dp_group(params[0]))
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

return AllGatherCoalescedHandle(
allgather_handle=handle,
params=params,
partitions=partitions,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
forward=forward,
)
else:
if params[0].ds_secondary_tensor is not None and not forward:
use_secondary_tensor = True
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter
from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
Expand Down

0 comments on commit b8e1664

Please sign in to comment.