import unittest
from unittest.mock import MagicMock, Mock, patch

import torch
import torch.distributed as dist

from vllm_ascend.distributed.communicator import NPUCommunicator


class TestNPUCommunicator(unittest.TestCase):

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_all_to_all_with_sizes(self, *_):
        """all_to_all with explicit scatter_sizes/gather_sizes returns the
        concatenation of the per-rank tensors produced by the collective."""

        # Stand-in for the real collective: write fixed per-rank outputs
        # into output_tensor_list instead of communicating.
        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            output_tensor_list[:] = [
                torch.tensor([10, 20]),
                torch.tensor([50, 60])
            ]

        torch.distributed.all_to_all = patched_all_to_all

        scatter_sizes = [2, 2]
        gather_sizes = [2, 2]
        input_ = torch.tensor([10, 20, 30, 40])

        comm = NPUCommunicator(cpu_group=dist.group.WORLD)

        output = comm.all_to_all(input_,
                                 scatter_sizes=scatter_sizes,
                                 gather_sizes=gather_sizes)

        assert output.tolist() == [10, 20, 50, 60]
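
        # Optional variant (a sketch, not part of the original test): scope
        # the override with unittest.mock.patch so torch.distributed.all_to_all
        # is restored afterwards instead of staying replaced for the process.
        with patch("torch.distributed.all_to_all",
                   side_effect=patched_all_to_all):
            scoped_output = comm.all_to_all(input_,
                                            scatter_sizes=scatter_sizes,
                                            gather_sizes=gather_sizes)
        assert scoped_output.tolist() == [10, 20, 50, 60]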

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_all_to_all_without_sizes(self, *_):
        """all_to_all driven by scatter_dim/gather_dim gathers the per-rank
        tensors produced by the collective along gather_dim."""

        # Stand-in for the real collective, as in the test above.
        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            output_tensor_list[:] = [
                torch.tensor([[10, 20]]),
                torch.tensor([[50, 60]])
            ]

        torch.distributed.all_to_all = patched_all_to_all

        input_ = torch.tensor([[10, 20], [30, 40]])

        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

        assert output.tolist() == [[10, 20], [50, 60]]

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_dispatch(self, *_):
        """dispatch() forwards hidden_states and router_logits to the
        all2all_manager and returns its result unchanged."""
        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        comm.all2all_manager = Mock()
        hidden_states = torch.randn(2, 4, 8)
        router_logits = torch.randn(2, 4, 2)

        mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
        comm.all2all_manager.dispatch.return_value = mock_dispatch_result

        result_hidden, result_logits = comm.dispatch(hidden_states,
                                                     router_logits)

        assert torch.allclose(result_hidden, mock_dispatch_result[0])
        assert torch.allclose(result_logits, mock_dispatch_result[1])

        comm.all2all_manager.dispatch.assert_called_once_with(
            hidden_states, router_logits)

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_combine(self, *_):
        """combine() forwards hidden_states to the all2all_manager and
        returns its result unchanged."""
        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        comm.all2all_manager = Mock()
        hidden_states = torch.randn(2, 4, 8)

        mock_combine_result = torch.randn(2, 4, 8)
        comm.all2all_manager.combine.return_value = mock_combine_result

        result = comm.combine(hidden_states)

        assert torch.allclose(result, mock_combine_result)

        comm.all2all_manager.combine.assert_called_once_with(hidden_states)
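

# Optional convenience entry point; pytest/unittest discovery finds
# TestNPUCommunicator without it.
if __name__ == "__main__":
    unittest.main()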