Skip to content

Commit

Permalink
Use deepspeed.comm instead of torch.distributed (#5225)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyouzhi authored Mar 4, 2024
1 parent e6e8c13 commit acf0739
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
6 changes: 1 addition & 5 deletions deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
class _AllToAll(torch.autograd.Function):

@staticmethod
def forward(
ctx: Any,
# TODO: replace with DS process group
group: torch.distributed.ProcessGroup,
input: Tensor) -> Tensor: # type: ignore
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
Expand Down
3 changes: 1 addition & 2 deletions deepspeed/runtime/comm/coalesced_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import torch
from torch import Tensor
from deepspeed import comm as dist
# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
from torch.distributed import ProcessGroup, all_to_all_single
from deepspeed.comm import ProcessGroup, all_to_all_single
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import instrument_w_nvtx
from deepspeed.ops import op_builder
Expand Down

0 comments on commit acf0739

Please sign in to comment.