fix parallel methods (#40)
msftsw authored Nov 15, 2021
1 parent ef8bcc8 commit 6b434d9
Showing 2 changed files with 34 additions and 15 deletions.
46 changes: 32 additions & 14 deletions tutel/impls/communicate.py
@@ -68,30 +68,48 @@ def backward(ctx: Any, grad_output: Tensor):
 
 class PreAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
+    def forward(ctx, group, input):
         ctx.group = group
-        return input
-
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
+            return input
+        ctx.input_shape = input.shape
+        output = torch.empty([ctx.num_nodes, input.numel()], device=input.device, dtype=input.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(output, chunks=ctx.num_nodes, dim=0)]
+        dist.all_gather(tensor_list=tensor_list, tensor=input.contiguous())
+        output = output.view(list(input.shape[:0]) + [input.shape[0] * ctx.num_nodes] + list(input.shape[1:]))
+        return output
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
+    def backward(ctx, doutput):
         if get_world_size(ctx.group) <= 1:
-            return (None, grad_output)
-        dinput = torch.clone(grad_output).contiguous()
-        dist.all_reduce(dinput, op=torch.distributed.ReduceOp.SUM)
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        chunks = [x.contiguous() for x in torch.chunk(doutput.view(ctx.num_nodes, -1), chunks=ctx.num_nodes, dim=0)]
+        dist.reduce_scatter(output=dinput, input_list=chunks)
         return (None, dinput)
 
 class PostAllreduceSum(torch.autograd.Function):
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
-        if get_world_size(group) <= 1:
+    def forward(ctx, group, input):
+        ctx.group = group
+        ctx.num_nodes = get_world_size(ctx.group)
+        if ctx.num_nodes <= 1:
             return input
-        output = torch.clone(input).contiguous()
-        dist.all_reduce(output, op=torch.distributed.ReduceOp.SUM)
+        ctx.input_shape = input.shape
+        ctx.leading_dim = 0
+        chunks = [x.contiguous() for x in torch.chunk(input, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        assert len(chunks) == ctx.num_nodes
+        output = torch.empty_like(chunks[0])
+        dist.reduce_scatter(output=output, input_list=list(chunks))
         return output
-
     @staticmethod
-    def backward(ctx: Any, grad_output: Tensor):
-        return (None, grad_output)
+    def backward(ctx, doutput):
+        if ctx.num_nodes <= 1:
+            return (None, doutput)
+        dinput = torch.empty(ctx.input_shape, device=doutput.device, dtype=doutput.dtype)
+        tensor_list = [x.contiguous() for x in torch.chunk(dinput, chunks=ctx.num_nodes, dim=ctx.leading_dim)]
+        dist.all_gather(tensor_list=tensor_list, tensor=doutput)
+        return (None, dinput)
 
 
 # A2A_TYPE: 0 for skip AllToAll, 1 for standard Pytorch AllToAll, 9 for standard Pytorch AllToAll with Timing
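
Note (not part of the commit): the change replaces the clone-and-all_reduce logic with an all_gather in PreAllreduceSum.forward and a reduce_scatter in its backward, and the mirror image in PostAllreduceSum. The sketch below is a single-process illustration in which the "ranks" are simulated as the leading tensor axis, so it checks the identity those backward passes rely on (the gradient of an all-gather is a reduce-scatter of the per-rank gradients) with plain tensor ops and no process group.

# Local sanity check (illustrative only): grad(all_gather) == reduce_scatter(per-rank grads).
# Ranks are simulated as the leading axis; no torch.distributed calls are made.
import torch

R, m, n = 4, 3, 5                                   # 4 simulated ranks, each owning an [m, n] shard
shards = torch.randn(R, m, n, requires_grad=True)

# Simulated all_gather (PreAllreduceSum.forward): every rank sees the full [R*m, n] tensor.
gathered = shards.reshape(R * m, n)

# Each simulated rank applies its own downstream computation to its copy of `gathered`.
weights = torch.randn(R, R * m, n)
total_loss = sum((weights[r] * gathered).sum() for r in range(R))
total_loss.backward()

# Simulated reduce_scatter of the per-rank gradients (PreAllreduceSum.backward):
# sum the gradients across ranks, then each rank keeps only its own slice.
expected = weights.sum(dim=0).reshape(R, m, n)
print(torch.allclose(shards.grad, expected))        # True
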
3 changes: 2 additions & 1 deletion tutel/impls/moe_layer.py
@@ -157,7 +157,8 @@ def named_parameters(self):
     def apply_on_expert_fn(self, input, expert_fn, group):
         if self.l_zero is None:
             self.l_zero = torch.tensor(0, dtype=input.dtype, device=input.device)
-        result_output = expert_fn(PreAllreduceSum.apply(group, input))
+        gathered_input = PreAllreduceSum.apply(group, input)
+        result_output = expert_fn(gathered_input)
         result_output = PostAllreduceSum.apply(group, result_output)
         return result_output, self.l_zero
 
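
Note (illustrative, not from the repository): the wrapper above now materializes the gathered input before calling expert_fn. The local mock below mimics the collectives with cat/sum/chunk to show the shape flow that PreAllreduceSum and PostAllreduceSum set up around expert_fn; world_size, tokens, and model_dim are arbitrary assumed values.

# Shape-only mock of the apply_on_expert_fn data flow; collectives are simulated locally.
import torch

world_size, tokens, model_dim = 2, 8, 16
per_rank_inputs = [torch.randn(tokens, model_dim) for _ in range(world_size)]

def expert_fn(x):                                        # stand-in for the real expert computation
    return x * 2.0

# PreAllreduceSum.forward (all_gather): every rank sees the concatenation of all shards.
gathered_input = torch.cat(per_rank_inputs, dim=0)       # [tokens * world_size, model_dim]
per_rank_outputs = [expert_fn(gathered_input) for _ in range(world_size)]

# PostAllreduceSum.forward (reduce_scatter): sum across ranks, each rank keeps its own chunk.
summed = torch.stack(per_rank_outputs).sum(dim=0)
result_shards = torch.chunk(summed, chunks=world_size, dim=0)
print(result_shards[0].shape)                            # torch.Size([8, 16])
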
