@@ -14,65 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Optional

 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from vllm.distributed.device_communicators.base_device_communicator import \
+    DeviceCommunicatorBase


-class NPUCommunicator:
+class NPUCommunicator(DeviceCommunicatorBase):

-    def __init__(self, group, unique_name=""):
-        self.group = group
-        self.unique_name = unique_name
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(self.group)
-        self.ranks = dist.get_process_group_ranks(self.group)
-        global_rank = dist.get_rank()
-        self.rank_in_group = dist.get_group_rank(self.group, global_rank)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(x, group=self.group)
-        return x
-
-    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
-        # NOTE: We assume that the input tensor is on the same device across
-        # all the ranks.
-        # NOTE: `dst` is the local rank of the destination rank.
-        # Allocate output tensor.
-        if self.rank_in_group == dst:
-            gather_list = [
-                torch.empty_like(input_) for _ in range(self.world_size)
-            ]
-        else:
-            gather_list = None
-        # Gather.
-        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
-        if self.rank_in_group == dst:
-            output_tensor = torch.cat(gather_list, dim=dim)
-        else:
-            output_tensor = None
-        return output_tensor
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
-        # Reshape
-        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (self.world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
+        # init device according to local rank
+        local_rank = dist.get_rank(device_group)
+        self.device = torch.device(f"npu:{local_rank}")
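For reference, here is a minimal usage sketch of the refactored class, not part of the diff. It assumes torch.distributed has already been initialized, that the default WORLD group can stand in for the real CPU/device process groups vLLM would pass, that vllm_ascend.distributed.communicator is the module path of the file being changed, and that the collectives removed here (all_reduce, gather, all_gather) are now inherited from DeviceCommunicatorBase.

    # Hypothetical usage sketch under the assumptions stated above.
    import torch
    import torch.distributed as dist

    # Assumed module path for the class changed in this diff.
    from vllm_ascend.distributed.communicator import NPUCommunicator

    # Assumes dist.init_process_group(...) has already been called by vLLM.
    comm = NPUCommunicator(cpu_group=dist.group.WORLD,
                           device_group=dist.group.WORLD,
                           unique_name="tp")

    x = torch.ones(4, device=comm.device)  # placed on this rank's NPU
    x = comm.all_reduce(x)                 # collective now provided by DeviceCommunicatorBase
    y = comm.all_gather(x, dim=-1)         # likewise inherited from the base class

The behavioral change is that NPUCommunicator no longer re-implements the collectives; it only pins self.device to the rank's NPU (npu:<rank in device_group>) and defers everything else to the base class.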