Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions vllm_ascend/distributed/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.forward_context import get_forward_context
from vllm.distributed.parallel_state import get_dp_group


class NPUCommunicator(DeviceCommunicatorBase):
Expand Down Expand Up @@ -73,3 +75,47 @@ def all_to_all(self,
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor

def naive_multicast(self, x: torch.Tensor,
                    cu_tokens_across_dp_cpu: torch.Tensor):
    """Assemble every rank's rows of ``x`` into one shared buffer.

    A buffer with ``cu_tokens_across_dp_cpu[-1]`` rows and ``x.size(1)``
    columns is allocated on ``x``'s device. This rank copies ``x`` into its
    own row range, then each rank in the group returned by
    ``get_dp_group()`` broadcasts its row range in turn.

    Args:
        x: 2-D tensor holding this rank's rows.
        cu_tokens_across_dp_cpu: indexed per rank; entry ``i`` is used as
            the end row of rank ``i`` and entry ``i - 1`` as its start row
            (rank 0 starts at row 0).
            # NOTE(review): presumably cumulative token counts held on
            # CPU, as the name suggests -- confirm against the producer.

    Returns:
        The filled buffer.
    """
    assert len(x.shape) == 2
    group = get_dp_group()
    own_rank = group.rank_in_group

    def _row_range(rank):
        # Rows owned by `rank`: [previous entry, own entry), rank 0 from 0.
        lo = cu_tokens_across_dp_cpu[rank - 1] if rank != 0 else 0
        hi = cu_tokens_across_dp_cpu[rank]
        return lo, hi

    total_rows = cu_tokens_across_dp_cpu[-1]
    gathered = torch.empty((total_rows, x.size(1)),
                           device=x.device,
                           dtype=x.dtype)

    # Place the local rows first so this rank's broadcast carries real data.
    lo, hi = _row_range(own_rank)
    gathered[lo:hi, :].copy_(x)

    # Each rank acts as the broadcast source for its own slice.
    for src in range(group.world_size):
        lo, hi = _row_range(src)
        group.broadcast(gathered[lo:hi, :], src)

    return gathered

def dispatch(self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor):
    """Multicast ``hidden_states`` and ``router_logits`` across the group.

    The per-rank offsets are read from
    ``get_forward_context().dp_metadata.cu_tokens_across_dp_cpu`` and both
    tensors are passed through ``naive_multicast`` with those offsets.

    Returns:
        A ``(hidden_states, router_logits)`` tuple of the multicast results.
    """
    dp_metadata = get_forward_context().dp_metadata
    offsets = dp_metadata.cu_tokens_across_dp_cpu

    gathered_hidden = self.naive_multicast(hidden_states, offsets)
    gathered_logits = self.naive_multicast(router_logits, offsets)
    return gathered_hidden, gathered_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """All-reduce ``hidden_states`` and return this rank's row slice.

    The row range is taken from
    ``get_forward_context().dp_metadata.cu_tokens_across_dp_cpu``: it ends
    at the entry for this rank and starts at the previous rank's entry
    (row 0 for rank 0). The all-reduce runs on the group returned by
    ``get_dp_group()``.

    Returns:
        Rows ``[start:end]`` of the all-reduced tensor.
    """
    offsets = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    group = get_dp_group()
    own_rank = group.rank_in_group

    if own_rank == 0:
        lo = 0
    else:
        lo = offsets[own_rank - 1]
    hi = offsets[own_rank]

    reduced = group.all_reduce(hidden_states)
    return reduced[lo:hi, :]
Loading