diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 7f8d5f75c6..c014508ae2 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -683,3 +683,40 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
         self.assertEqual(result.shape[2], self.impl.v_head_dim)
         mock_up_proj.assert_called_once()
         mock_page_attention_mla.assert_called_once()
+
+    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
+    @patch("torch_npu._npu_reshape_and_cache")
+    def test_forward_without_graph(self, _, mock_forward_prefill):
+        self.impl.running_in_graph = False
+        self.impl.torchair_graph_enabled = False
+
+        num_tokens = 100
+        num_blocks = 256
+        block_size = 4
+        rotary_emb_return_value = (torch.randn(num_tokens, 16,
+                                               self.impl.kv_lora_rank),
+                                   torch.randn(0, 1, self.impl.kv_lora_rank))
+        self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
+        self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
+            1, num_blocks, 128)
+
+        hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
+        hidden_states_or_kv_c_normed = torch.randn(num_tokens,
+                                                   self.impl.kv_lora_rank)
+        k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
+        kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
+                                self.impl.kv_lora_rank),
+                    torch.randn(num_blocks, block_size, self.impl.num_heads,
+                                self.impl.qk_rope_head_dim))
+        output = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.v_head_dim)
+
+        metadata = MagicMock()
+        metadata.num_decodes = 0
+        metadata.num_prefills = num_tokens
+        mock_forward_prefill.return_value = torch.randn(
+            0, self.impl.num_heads * self.impl.v_head_dim)
+        result = self.impl.forward(None, hidden_states_or_q_c,
+                                   hidden_states_or_kv_c_normed, k_pe,
+                                   kv_cache, metadata, output, False)
+        self.assertEqual(result.shape[0], num_tokens)
diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py
new file mode 100644
index 0000000000..b2075b667c
--- /dev/null
+++ b/tests/ut/quantization/test_w8a8_dynamic.py
@@ -0,0 +1,70 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all
+
+
+class TestAscendW8A8FusedMoEMethod(TestBase):
+
+    def setUp(self):
+        self.hidden_size = 128
+        self.num_tokens = 128
+        self.placeholder = torch.randn(self.num_tokens, self.hidden_size)
+
+    @patch("torch.distributed.all_to_all_single")
+    @patch("torch_npu.npu_moe_re_routing")
+    @patch("torch_npu.npu_grouped_matmul")
+    @patch("torch_npu.npu_swiglu")
+    @patch("torch_npu.npu_dynamic_quant")
+    @patch("torch_npu.npu_moe_finalize_routing")
+    @patch("torch_npu.npu_moe_init_routing")
+    def test_fused_experts_with_all2all(self, mock_moe_init_routing,
+                                        mock_moe_finalize_routing,
+                                        mock_dynamic_quant, mock_swiglu,
+                                        mock_grouped_matmul,
+                                        mock_moe_re_routing,
+                                        mock_all_to_all_single):
+        expert_map = MagicMock()
+        ep_group = MagicMock()
+        placeholder_int8 = torch.randint(0,
+                                         100,
+                                         (self.num_tokens, self.hidden_size),
+                                         dtype=torch.int8)
+        placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
+        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
+            input)
+        mock_moe_init_routing.return_value = (
+            placeholder_int8,
+            placeholder_ones,
+            placeholder_ones,
+        )
+        mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
+                                            torch.randint(0,
+                                                          100,
+                                                          (self.num_tokens, ),
+                                                          dtype=torch.int32),
+                                            self.placeholder)
+        mock_grouped_matmul.return_value = self.placeholder
+        mock_swiglu.return_value = self.placeholder
+        mock_dynamic_quant.return_value = (
+            placeholder_int8,
+            torch.randn(self.num_tokens),
+        )
+        mock_moe_finalize_routing.return_value = self.placeholder
+
+        fused_experts_with_all2all(
+            hidden_states=self.placeholder,
+            w1=self.placeholder,
+            w1_scale=self.placeholder,
+            w2=self.placeholder,
+            w2_scale=self.placeholder,
+            topk_weights=self.placeholder,
+            topk_ids=self.placeholder,
+            top_k=8,
+            expert_map=expert_map,
+            ep_group=ep_group,
+            log2phy=None,
+            global_redundant_expert_num=256,
+        )
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index b8fd24e7d1..7466c539b6 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -47,6 +47,8 @@ def __init__(self, vllm_config):
         self.expert_map_path = additional_config.get("expert_map_path", None)
         self.chunked_prefill_for_mla = additional_config.get(
             "chunked_prefill_for_mla", False)
+        self.enable_prefill_optimizations = additional_config.get(
+            "enable_prefill_optimizations", False)


 class TorchairGraphConfig:
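The new enable_prefill_optimizations flag is read from additional_config, next to the existing chunked_prefill_for_mla key. A minimal usage sketch follows, assuming the usual vLLM additional_config plumbing that vllm-ascend already relies on for its other knobs; the model id and the comment are illustrative, not part of this patch.

# Hypothetical sketch: turning the new flag on through additional_config.
# Assumes vLLM's LLM entrypoint forwards additional_config to AscendConfig,
# as it does for existing keys such as "chunked_prefill_for_mla".
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative model id
    additional_config={
        # Ignored when torchair graph mode is enabled, per the
        # "and not torchair_graph_enabled" guards added in this patch.
        "enable_prefill_optimizations": True,
    },
)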
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b2b3ad0e59..8a6b714552 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -587,6 +587,8 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.enable_prefill_optimizations = \
+            ascend_config.enable_prefill_optimizations and not self.torchair_graph_enabled

         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -601,6 +603,8 @@ def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if hasattr(self, "running_in_graph") and not self.running_in_graph:
+            return x
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
         npu_prefetch(self.o_proj.weight,
                      x,
@@ -871,14 +875,7 @@ def _forward_prefill(
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output = attn_output_torch

-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self.o_proj(attn_output, is_prefill=True)[0]
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output, is_prefill=True)[0]
+        return attn_output

     def exec_kv(
         self,
@@ -1208,6 +1205,12 @@
                 key_cache=kv_cache[0],
                 value_cache=kv_cache[1],
                 slot_indices=attn_metadata.slot_mapping)
+        if not self.running_in_graph:
+            o_proj_input_shape = (num_actual_toks,
+                                  self.num_heads * self.v_head_dim)
+            o_proj_input = torch.empty(o_proj_input_shape,
+                                       dtype=hidden_states_or_q_c.dtype,
+                                       device=hidden_states_or_q_c.device)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
@@ -1218,11 +1221,12 @@
                                                attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
+                current_ms_metadata.before_comm_event.record()
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    output[num_decode_tokens:] = output_prefill
-                    current_ms_metadata.after_comm_event.record()
+                    current_ms_metadata.before_comm_event.wait()
+                    o_proj_input[num_decode_tokens:] = output_prefill
             else:
-                output[num_decode_tokens:] = output_prefill
+                o_proj_input[num_decode_tokens:] = output_prefill

         if has_decode:
             if self.running_in_graph:
@@ -1239,9 +1243,34 @@
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    output[:num_decode_tokens] = output_decode
+                    o_proj_input[:num_decode_tokens] = output_decode
                     current_ms_metadata.after_comm_event.record()
             else:
-                output[:num_decode_tokens] = output_decode
+                o_proj_input[:num_decode_tokens] = output_decode

+        current_ms_metadata = get_multistream_comm_context()
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        if current_ms_metadata is None:
+            npu_prefetch(self.o_proj.weight,
+                         o_proj_input,
+                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                         enabled=enable_multistream_mla)
+
+            _o = self.o_proj(
+                o_proj_input,
+                is_prefill=True,
+                is_force_scatter=self.enable_prefill_optimizations)[0]
+            output[...] = _o
+        else:
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                npu_prefetch(self.o_proj.weight,
+                             o_proj_input,
+                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                             enabled=enable_multistream_mla)
+                output[...] = self.o_proj(
+                    o_proj_input,
+                    is_prefill=True,
+                    is_force_scatter=self.enable_prefill_optimizations)[0]
+                current_ms_metadata.after_comm_event.record()
+        del o_proj_input
         return output_padded
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index e1c2b1cc68..85def52dc6 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
@@ -160,7 +161,13 @@ def forward(
                                                   input_parallel,
                                                   bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+            num_tokens = output_parallel.shape[0]
+            if is_force_scatter and num_tokens % self.tp_size:
+                output_parallel = nn.functional.pad(
+                    output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
+            if is_force_scatter or (not is_prefill
+                                    and output_parallel.shape[0] % self.tp_size
+                                    == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
                                                               dim=0)
             else:
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
@@ -347,13 +355,16 @@ def __init__(
             reduce_results = not self.all_reduce_merge
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
+            enable_prefill_optimizations = \
+                ascend_config.enable_prefill_optimizations and not ascend_config.torchair_graph_config.enabled
             self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=reduce_results,
-                force_replicate=self.enable_multistream_moe,
+                force_replicate=self.enable_multistream_moe
+                or enable_prefill_optimizations,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -447,9 +458,9 @@ def __init__(
         self.kv_lora_rank = kv_lora_rank

         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % self.tp_size == 0
+        self.num_local_heads = num_heads // self.tp_size

         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
@@ -462,6 +473,8 @@
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
+        self.enable_prefill_optimizations = \
+            ascend_config.enable_prefill_optimizations and not self.torchair_graph_enabled

         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@
                                              prefix=f"{prefix}.kv_b_proj")
         if (config.n_routed_experts is not None
                 and self.debug_layer_idx >= config.first_k_dense_replace
-                and self.debug_layer_idx % config.moe_layer_freq == 0 and
-                ascend_config.torchair_graph_config.enable_multistream_moe):
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.torchair_graph_config.enable_multistream_moe
+                     or self.enable_prefill_optimizations)):
             self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
@@ -596,13 +610,33 @@ def forward(
             output = output.view(-1, output_shape[-1])
             return output
         else:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
+            if self.enable_prefill_optimizations and self.debug_layer_idx > 3 and self.debug_layer_idx < 61:
+                hidden_states_or_q_c = get_tp_group().all_gather(
+                    hidden_states_or_q_c, 0)
+                kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+            forward_context = get_forward_context()
+            attn_metadata = forward_context.attn_metadata
+            if attn_metadata is not None:
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                num_tokens = hidden_states_or_q_c.shape[0]
+
+            kv_c, k_pe = kv_no_split.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
             kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.enable_prefill_optimizations or self.debug_layer_idx < 3:
+                output_shape = hidden_states.shape
+            else:
+                num_tokens = hidden_states_or_q_c.shape[0]
+                rows = num_tokens // self.tp_size
+                if num_tokens % self.tp_size:
+                    rows += 1
+                output_shape = (rows, hidden_states.shape[1])
             return self.mla_attn(hidden_states_or_q_c,
                                  kv_c_normed,
                                  k_pe,
-                                 output_shape=hidden_states.shape)
+                                 output_shape=output_shape)


 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -677,6 +711,9 @@
                                      eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.tp_group = get_tp_group().device_group
+        self.enable_prefill_optimizations = \
+            ascend_config.enable_prefill_optimizations and not ascend_config.torchair_graph_config.enabled

     def forward(
         self,
@@ -731,6 +768,17 @@
                 # first layer.
                 residual *= 1. / self.routed_scaling_factor

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.enable_prefill_optimizations and (
+                self.layer_idx == 3 or self.layer_idx == 61) and tp_size > 1:
+            num_tokens, _ = residual.shape
+            if num_tokens % tp_size:
+                residual = nn.functional.pad(residual,
+                                             (0, 0, 0, -num_tokens % tp_size))
+            chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
+            tp_rank = get_tensor_model_parallel_rank()
+            residual = chunk_residual[tp_rank]
+
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
@@ -756,6 +804,21 @@
                                                            dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)

+        # for last layer of main model and mtp layer.
+        if self.enable_prefill_optimizations and self.layer_idx >= 60 and tp_size > 1:
+            hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            residual = get_tp_group().all_gather(residual, 0)
+
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is not None:
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                num_tokens = hidden_states.shape[0]
+
+            if num_tokens < hidden_states.shape[0]:
+                hidden_states = hidden_states[:num_tokens]
+                residual = residual[:num_tokens]
+
         return hidden_states, residual
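The padded reduce-scatter and per-rank output shape above rest on two pieces of arithmetic from the patch: -num_tokens % tp_size is the smallest padding that makes the token count divisible by tp_size, and num_tokens // tp_size rounded up is the per-rank row count used for output_shape. A small standalone check with toy numbers (no NPU or vLLM required, values are illustrative only):

# Toy check of the padding arithmetic behind is_force_scatter and the
# per-rank output_shape computation in this patch.
num_tokens, tp_size = 100, 16

pad = -num_tokens % tp_size                # 12 rows of padding
assert (num_tokens + pad) % tp_size == 0   # 112 rows split evenly across ranks

rows = num_tokens // tp_size
if num_tokens % tp_size:
    rows += 1                              # ceiling division -> 7 rows per rank
assert rows == (num_tokens + pad) // tp_size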
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index b2b1ab9bb3..cdb000177b 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1265,6 +1265,8 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_prefill_optimizations = \
+            ascend_config.enable_prefill_optimizations and not self.torchair_graph_enabled

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1394,22 +1396,24 @@ def forward(self,
         else:
             # TODO: Determine if we can remove the padding
             padding_size = tp_size
-            if num_tokens < padding_size:
+            if num_tokens < padding_size and not self.enable_prefill_optimizations:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, padding_size - num_tokens))
                 router_logits = nn.functional.pad(
                     router_logits, (0, 0, 0, padding_size - num_tokens))
             if tp_size > 1:
-                chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                         tp_size,
-                                                         dim=0)
-                chunk_router_logits = torch.tensor_split(router_logits,
-                                                         tp_size,
-                                                         dim=0)
-                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 tp_rank = get_tensor_model_parallel_rank()
-                hidden_states = chunk_hidden_states[tp_rank]
-                router_logits = chunk_router_logits[tp_rank]
+                if not self.enable_prefill_optimizations:
+                    chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                             tp_size,
+                                                             dim=0)
+                    chunk_router_logits = torch.tensor_split(router_logits,
+                                                             tp_size,
+                                                             dim=0)
+                    hidden_states = chunk_hidden_states[tp_rank]
+                    router_logits = chunk_router_logits[tp_rank]
+
+                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]

         if self.dp_size > 1:
@@ -1476,7 +1480,7 @@ def forward(self,
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
-        ] and not replace_allreduce):
+        ] and not replace_allreduce and not self.enable_prefill_optimizations):
             if tp_size > 1:
                 dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                                 self.tp_group)
@@ -1486,7 +1490,7 @@ def forward(self,
             final_hidden_states = e_hidden_states
             if num_tokens < padding_size:
                 final_hidden_states = final_hidden_states[:num_tokens]
-        elif self.dp_size > 1:
+        elif self.dp_size > 1 and not self.enable_prefill_optimizations:
             if fused_moe_state == FusedMoEState.NaiveMulticast:
                 start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
                     self.dp_rank - 1]
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index b20ffa3b6d..e4afbb5611 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output


+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(
@@ -358,50 +381,54 @@
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device

     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
+
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
+
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
+
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
+
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
@@ -419,6 +446,7 @@
                                                   expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 0
+        dynamic_scale = None

     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
@@ -428,14 +456,19 @@
         w2,
         w2_scale,
         expert_tokens,  #16
+        dynamic_scale=dynamic_scale,
         group_list_type=group_list_type)

     if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)

         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
@@ -444,8 +477,8 @@
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,
+            drop_pad_mode=2)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
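The CPU-side split-size bookkeeping in the quantized all2all path above can be checked in isolation. A toy sketch with 2 EP ranks and 2 local experts per rank follows; the tensor values are made up, and in the patch gather_sizes is produced by the first all_to_all_single over global_expert_tokens rather than hard-coded.

# Toy, CPU-only walk-through of the split-size computation used before the
# quantized token all2all (values illustrative, no torch.distributed needed).
import torch

world_size = 2                                     # EP ranks (toy value)
global_expert_tokens = torch.tensor([3, 1, 2, 4])  # tokens this rank routes to each global expert
gather_sizes = torch.tensor([5, 2, 0, 3])          # per-expert counts received from the size all2all

combined = torch.stack([gather_sizes, global_expert_tokens], dim=0)
combined = combined.view(2, world_size, -1).sum(dim=2)
combined_cpu = combined.to(torch.device("cpu")).numpy()

scatter_size_list = combined_cpu[0]  # [7, 3]: rows this rank receives from each peer
gather_size_list = combined_cpu[1]   # [4, 6]: rows this rank sends to each peer
# In the patch these become the output/input split sizes of
# dist.all_to_all_single(gathered_tokens, quantized_tokens, ...).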