|
| 1 | +from abc import ABC, abstractmethod |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch_npu |
| 5 | +from vllm.forward_context import ForwardContext, get_forward_context |
| 6 | +from vllm.utils import direct_register_custom_op |
| 7 | + |
| 8 | + |
| 9 | +class MoECommMethod(ABC): |
| 10 | + """Base class for MoE communication methods.""" |
| 11 | + |
| 12 | + def __init__( |
| 13 | + self, |
| 14 | + device: torch.device, |
| 15 | + dtype: torch.dtype, |
| 16 | + top_k_num: int, |
| 17 | + global_num_experts: int, |
| 18 | + ): |
| 19 | + self.device = device |
| 20 | + self.dtype = dtype |
| 21 | + self.top_k_num = top_k_num |
| 22 | + self.global_num_experts = global_num_experts |
| 23 | + |
| 24 | + @abstractmethod |
| 25 | + def _pre_process( |
| 26 | + self, |
| 27 | + hidden_states: torch.Tensor, |
| 28 | + topk_ids: torch.Tensor, |
| 29 | + topk_weights: torch.Tensor, |
| 30 | + expert_map: torch.Tensor, |
| 31 | + num_experts: int, |
| 32 | + ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 33 | + """Pre-process before MLP.""" |
| 34 | + pass |
| 35 | + |
| 36 | + @abstractmethod |
| 37 | + def _post_process(self, mlp_output: torch.Tensor, |
| 38 | + hidden_states: torch.Tensor) -> None: |
| 39 | + """Post-process after MLP.""" |
| 40 | + pass |
| 41 | + |
| 42 | + |
| 43 | +class DummyCommImpl(MoECommMethod): |
| 44 | + |
| 45 | + def _pre_process( |
| 46 | + self, |
| 47 | + hidden_states: torch.Tensor, |
| 48 | + topk_ids: torch.Tensor, |
| 49 | + topk_weights: torch.Tensor, |
| 50 | + expert_map: torch.Tensor, |
| 51 | + num_experts: int, |
| 52 | + ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 53 | + return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights, |
| 54 | + expert_map, num_experts) |
| 55 | + |
| 56 | + def _post_process(self, mlp_output: torch.Tensor, |
| 57 | + hidden_states: torch.Tensor) -> None: |
| 58 | + """Dummy implementation that does nothing.""" |
| 59 | + pass |
| 60 | + |
| 61 | + |
| 62 | +class AllGatherCommImpl(MoECommMethod): |
| 63 | + """This implementation is for the scenarios listed below: |
| 64 | + 1. `enable_expert_parallel=True`. |
| 65 | +
|
| 66 | + Note that this implementation purely consists of native PyTorch ops |
| 67 | + and does not use any NPU-specific ops. So the performance may not be optimal. |
| 68 | + But it is a good fallback for scenarios where NPU-specific ops are not available. |
| 69 | + """ |
| 70 | + |
| 71 | + def _pre_process( |
| 72 | + self, |
| 73 | + hidden_states: torch.Tensor, |
| 74 | + topk_ids: torch.Tensor, |
| 75 | + topk_weights: torch.Tensor, |
| 76 | + expert_map: torch.Tensor, |
| 77 | + num_experts: int, |
| 78 | + ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 79 | + num_tokens = hidden_states.shape[0] |
| 80 | + |
| 81 | + # Generate token indices and flatten |
| 82 | + token_indices = torch.arange(num_tokens, |
| 83 | + device=self.device, |
| 84 | + dtype=torch.int64) |
| 85 | + token_indices = (token_indices.unsqueeze(1).expand( |
| 86 | + -1, self.top_k_num).reshape(-1)) |
| 87 | + |
| 88 | + # Flatten token-to-expert mappings and map to local experts |
| 89 | + weights_flat = topk_weights.view(-1) |
| 90 | + experts_flat = topk_ids.view(-1) |
| 91 | + local_experts_flat = (expert_map[experts_flat] |
| 92 | + if expert_map is not None else experts_flat) |
| 93 | + |
| 94 | + # Filter valid token-expert pairs |
| 95 | + mask = local_experts_flat != -1 |
| 96 | + filtered_weights = torch.where(mask, weights_flat, |
| 97 | + torch.zeros_like(weights_flat)).to( |
| 98 | + self.dtype) |
| 99 | + filtered_experts = torch.where( |
| 100 | + mask, |
| 101 | + local_experts_flat, |
| 102 | + torch.full_like(local_experts_flat, num_experts), |
| 103 | + ).to(topk_ids.dtype) |
| 104 | + self.mask = mask |
| 105 | + |
| 106 | + # Sort by local expert IDs |
| 107 | + sort_indices = torch.argsort(filtered_experts.view(torch.float32)) |
| 108 | + self.sorted_token_indices = token_indices[sort_indices] |
| 109 | + self.sorted_weights = filtered_weights[sort_indices] |
| 110 | + |
| 111 | + # Compute token counts with minlength of num_experts |
| 112 | + # This is equivalent to but faster than: |
| 113 | + # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] |
| 114 | + token_counts = torch.zeros(num_experts + 1, |
| 115 | + device=self.device, |
| 116 | + dtype=torch.int64) |
| 117 | + ones = torch.ones_like(filtered_experts, dtype=torch.int64) |
| 118 | + token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) |
| 119 | + token_counts = token_counts[:num_experts] |
| 120 | + expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64) |
| 121 | + |
| 122 | + # Rearrange hidden_states |
| 123 | + permuted_hidden_states = hidden_states[self.sorted_token_indices] |
| 124 | + |
| 125 | + group_list_type = 0 |
| 126 | + |
| 127 | + return permuted_hidden_states, expert_tokens, group_list_type |
| 128 | + |
| 129 | + def _post_process(self, mlp_output: torch.Tensor, |
| 130 | + hidden_states: torch.Tensor) -> None: |
| 131 | + weighted_down_out = mlp_output * self.sorted_weights.unsqueeze(1) |
| 132 | + |
| 133 | + final_hidden_states = torch.zeros_like(hidden_states) |
| 134 | + |
| 135 | + # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] |
| 136 | + # This created multiple NaN and index_add_ will mix them up which harms accuracy |
| 137 | + # remove this mask and filter after it being fixed |
| 138 | + num_valid_tokens = self.mask.sum() |
| 139 | + valid_token_mask = (torch.arange( |
| 140 | + 0, self.sorted_token_indices.shape[0], |
| 141 | + device=self.device).unsqueeze(1) < num_valid_tokens) |
| 142 | + valid_output = torch.where(valid_token_mask, weighted_down_out, |
| 143 | + torch.zeros_like(weighted_down_out)).to( |
| 144 | + self.dtype) |
| 145 | + final_hidden_states.index_add_(0, self.sorted_token_indices, |
| 146 | + valid_output) |
| 147 | + |
| 148 | + hidden_states[:] = final_hidden_states |
| 149 | + |
| 150 | + |
| 151 | +class AllReduceCommImpl(MoECommMethod): |
| 152 | + """This implementation is for the scenarios listed below: |
| 153 | + 1. `enable_expert_parallel=False`. |
| 154 | + 2. If `npu_moe_init_routing_v2` is available, we will support `enable_expert_parallel=True`, |
| 155 | + and this implementation will become the default one, changing the name to `AllGather` at |
| 156 | + the same time. |
| 157 | + """ |
| 158 | + |
| 159 | + def _pre_process( |
| 160 | + self, |
| 161 | + hidden_states: torch.Tensor, |
| 162 | + topk_ids: torch.Tensor, |
| 163 | + topk_weights: torch.Tensor, |
| 164 | + expert_map: torch.Tensor, # noqa: F841 |
| 165 | + num_experts: int, |
| 166 | + ) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 167 | + num_tokens = hidden_states.shape[0] |
| 168 | + |
| 169 | + self.topk_weights = topk_weights |
| 170 | + self.topk_ids = topk_ids |
| 171 | + |
| 172 | + # 1. Prepare row indices for routing |
| 173 | + row_idx_len = num_tokens * self.top_k_num |
| 174 | + row_idx = torch.arange(row_idx_len, |
| 175 | + dtype=torch.int32, |
| 176 | + device=self.device) |
| 177 | + row_idx = row_idx.view(self.top_k_num, -1).permute(1, 0).contiguous() |
| 178 | + |
| 179 | + # 2. Initial routing to expand tokens and experts |
| 180 | + permuted_hidden_states, expanded_row_idx, expanded_expert_idx = ( |
| 181 | + torch_npu.npu_moe_init_routing( |
| 182 | + hidden_states, |
| 183 | + row_idx=row_idx, |
| 184 | + expert_idx=topk_ids, |
| 185 | + active_num=num_tokens, |
| 186 | + )) |
| 187 | + # NOTE: Currently, V2 produces incorrect accuracy and weaker performance than V1 |
| 188 | + # first_expert_idx = 0 |
| 189 | + # if expert_map is not None: |
| 190 | + # first_expert_idx = torch.nonzero(expert_map != -1, as_tuple=False)[0].item() |
| 191 | + # last_expert_idx = first_expert_idx + num_experts |
| 192 | + # permuted_hidden_states, expanded_row_idx, expert_tokens, _ = ( |
| 193 | + # torch_npu.npu_moe_init_routing_v2( |
| 194 | + # hidden_states, |
| 195 | + # topk_ids, |
| 196 | + # active_num=num_tokens * self.top_k_num, |
| 197 | + # expert_num=self.global_num_experts, |
| 198 | + # expert_tokens_num_type=1, # Only support `count` mode now |
| 199 | + # expert_tokens_num_flag=True, # Output `expert_tokens` |
| 200 | + # active_expert_range=[first_expert_idx, last_expert_idx], |
| 201 | + # quant_mode=-1, |
| 202 | + # ) |
| 203 | + # ) |
| 204 | + self.expanded_row_idx = expanded_row_idx |
| 205 | + permuted_hidden_states = permuted_hidden_states |
| 206 | + |
| 207 | + # 3. Compute expert tokens |
| 208 | + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( |
| 209 | + expanded_expert_idx, num_experts).to(torch.int64) |
| 210 | + # NOTE: This is also for npu_moe_init_routing_v2 |
| 211 | + # expert_tokens = torch.cumsum(expert_tokens, 0) |
| 212 | + |
| 213 | + group_list_type = 0 |
| 214 | + |
| 215 | + return permuted_hidden_states, expert_tokens, group_list_type |
| 216 | + |
| 217 | + def _post_process(self, mlp_output: torch.Tensor, |
| 218 | + hidden_states: torch.Tensor) -> None: |
| 219 | + hidden_states[:] = torch_npu.npu_moe_finalize_routing( |
| 220 | + mlp_output, |
| 221 | + skip1=None, |
| 222 | + skip2=None, |
| 223 | + bias=None, |
| 224 | + scales=self.topk_weights, |
| 225 | + expanded_src_to_dst_row=self.expanded_row_idx, |
| 226 | + export_for_source_row=self.topk_ids, |
| 227 | + # NOTE: For npu_moe_init_routing_v2 |
| 228 | + # drop_pad_mode=2, |
| 229 | + ) |
| 230 | + |
| 231 | + |
| 232 | +def moe_comm_pre_process( |
| 233 | + hidden_states: torch.Tensor, |
| 234 | + topk_ids: torch.Tensor, |
| 235 | + topk_weights: torch.Tensor, |
| 236 | + expert_map: torch.Tensor, |
| 237 | + num_experts: int, |
| 238 | +) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 239 | + forward_context: ForwardContext = get_forward_context() |
| 240 | + self = forward_context.moe_comm_method |
| 241 | + return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map, |
| 242 | + num_experts) |
| 243 | + |
| 244 | + |
| 245 | +def moe_comm_pre_process_fake( |
| 246 | + hidden_states: torch.Tensor, |
| 247 | + topk_ids: torch.Tensor, |
| 248 | + topk_weights: torch.Tensor, |
| 249 | + expert_map: torch.Tensor, |
| 250 | + num_experts: int, |
| 251 | +) -> tuple[torch.Tensor, torch.Tensor, int]: |
| 252 | + top_k_num = topk_ids.shape[1] |
| 253 | + permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0) |
| 254 | + expert_tokens = torch.zeros((num_experts, ), |
| 255 | + dtype=torch.int64, |
| 256 | + device=hidden_states.device) |
| 257 | + group_list_type = 0 |
| 258 | + return permuted_hidden_states, expert_tokens, group_list_type |
| 259 | + |
| 260 | + |
| 261 | +def moe_comm_post_process(mlp_output: torch.Tensor, |
| 262 | + hidden_states: torch.Tensor) -> None: |
| 263 | + forward_context: ForwardContext = get_forward_context() |
| 264 | + self = forward_context.moe_comm_method |
| 265 | + self._post_process(mlp_output, hidden_states) |
| 266 | + return |
| 267 | + |
| 268 | + |
| 269 | +direct_register_custom_op( |
| 270 | + op_name="moe_comm_pre_process", |
| 271 | + op_func=moe_comm_pre_process, |
| 272 | + mutates_args=[], |
| 273 | + fake_impl=moe_comm_pre_process_fake, |
| 274 | + dispatch_key="PrivateUse1", |
| 275 | +) |
| 276 | + |
| 277 | +direct_register_custom_op( |
| 278 | + op_name="moe_comm_post_process", |
| 279 | + op_func=moe_comm_post_process, |
| 280 | + mutates_args=["hidden_states"], |
| 281 | + fake_impl=lambda x, y: None, # No-op for fake implementation |
| 282 | + dispatch_key="PrivateUse1", |
| 283 | +) |
0 commit comments