Commit 2ef3d98

[0.9.1][Prefill Perf] add D2H & initRoutingQuantV2 (#2038)
### What this PR does / why we need it?
1. Delete the fused_experts_with_all2all_v2 method and fold its prefill optimization points into the all2all method.
2. Add access to D2H & initRoutingQuantV2. This PR will be cherry-picked to main in #2110.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Tested initRoutingQuantV2 with the PTA poc package.
PTA poc package link: https://pytorch-package.obs.cn-north-4.myhuaweicloud.com:443/pta/personal/cache/pytorch/v7.1.0-pytorch2.5.1-vllm/pytorchv7.1.0-pytorch2.5.1-vllm_3.11_aarch64.tar.gz

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
1 parent 92e6aa9 commit 2ef3d98
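The headline change is runtime access to the fused initRoutingQuantV2 kernel: the new all2all path probes whether the installed torch_npu (for example the PTA poc build linked above) exposes `npu_moe_init_routing_quant`, and otherwise falls back to composing `npu_moe_init_routing` with `npu_dynamic_quant` via the `init_routing_quant` helper added in the diff below. A minimal sketch of that capability probe, written so it also runs on a machine without the Ascend stack:

```python
import torch

try:
    import torch_npu  # present only in Ascend (PTA) PyTorch installs
except ImportError:  # assumption: allow this sketch to run without an NPU stack
    torch_npu = None

# Probe the op at runtime instead of pinning a torch_npu version: the fused
# routing + dynamic-quant kernel only ships in newer PTA builds such as the
# poc package linked in the commit message.
HAS_INIT_ROUTING_QUANT_V2 = (torch_npu is not None
                             and hasattr(torch_npu, "npu_moe_init_routing_quant"))

print("fused init-routing-quant available:", HAS_INIT_ROUTING_QUANT_V2)
```

The diff below uses exactly this `hasattr` check inside `fused_experts_with_all2all` to choose between the fused op and the fallback helper.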

File tree

1 file changed: +82 −245 lines changed


vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 82 additions & 245 deletions
```diff
@@ -110,73 +110,6 @@ def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
     return hidden_states


-def gmm_expert(hidden_states: torch.Tensor,
-               expert_tokens: List[int],
-               w1: torch.Tensor,
-               w2: torch.Tensor,
-               w1_scale: torch.Tensor,
-               w2_scale: torch.Tensor,
-               dynamic_scale: Optional[torch.Tensor] = None,
-               avg_tokens_per_expert: Optional[List[int]] = None,
-               local_num_experts: int = 16) -> torch.Tensor:
-    hidden_size = hidden_states.size(-1)
-
-    # Flatten input and dynamic_scale if needed
-    if dynamic_scale is not None and dynamic_scale.dim() > 1:
-        dynamic_scale = dynamic_scale.reshape(-1)
-    hidden_states = hidden_states.view(-1, hidden_size)
-
-    # First grouped matmul (up-projection)
-    mm1_output = torch_npu.npu_grouped_matmul(
-        [hidden_states], [w1],
-        group_list=expert_tokens,
-        split_item=3,
-        output_dtype=torch.int32,
-        group_type=0,
-        group_list_type=1,
-        tuning_config=avg_tokens_per_expert)[0]
-
-    # Prepare quantization scale for dequant + activation
-    quant_scale = torch.ones((local_num_experts, w1_scale.shape[-1] // 2),
-                             dtype=torch.float32,
-                             device=hidden_states.device)
-    w1_scale = w1_scale.to(torch.float32)
-
-    # Apply dequant + swiglu activation
-    intermediate_states, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-        x=mm1_output,
-        weight_scale=w1_scale,
-        activation_scale=dynamic_scale.squeeze(0)
-        if dynamic_scale is not None else None,
-        bias=None,
-        quant_scale=quant_scale,
-        quant_offset=None,
-        group_index=expert_tokens,
-        activate_left=True,
-        quant_mode=1)
-
-    # Flatten again if necessary
-    if dynamic_scale is not None and dynamic_scale.dim() > 1:
-        intermediate_states = intermediate_states.view(
-            -1, intermediate_states.size(-1))
-        dynamic_scale = dynamic_scale.reshape(-1)
-
-    # Final grouped matmul (down-projection)
-    output = torch_npu.npu_grouped_matmul(
-        [intermediate_states], [w2],
-        bias=None,
-        scale=[w2_scale.to(torch.bfloat16)],
-        per_token_scale=[dynamic_scale],
-        group_list=expert_tokens,
-        split_item=3,
-        output_dtype=torch.bfloat16,
-        group_type=0,
-        group_list_type=1,
-        tuning_config=avg_tokens_per_expert)[0]
-
-    return output
-
-
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
```
```diff
@@ -525,6 +458,29 @@ def fused_prefill_experts_with_mc2(
     return hidden_states_outputs, shared_outputs, expert_token_nums, group_list_type


+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(hidden_states: torch.Tensor,
```
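For reference, a CPU-only toy of the index bookkeeping this fallback performs around the NPU ops (toy shapes, no torch_npu calls): `row_idx` is an arange laid out as (top_k, num_tokens) and transposed so row t holds the routing-slot ids of token t, and `global_expert_tokens` is a bincount of the flattened top-k expert assignments over the global expert set.

```python
import torch

# Toy shapes for illustration only.
num_tokens, top_k, global_num_experts = 4, 2, 8
topk_ids = torch.tensor([[1, 5], [0, 5], [7, 1], [3, 2]], dtype=torch.int32)

# Same layout as the fallback: arange over num_tokens * top_k, viewed as
# (top_k, num_tokens) and transposed, so row t holds the slot ids of token t.
row_idx = (torch.arange(num_tokens * top_k, dtype=torch.int32)
           .view(top_k, -1).permute(1, 0).contiguous())
print(row_idx)
# tensor([[0, 4],
#         [1, 5],
#         [2, 6],
#         [3, 7]], dtype=torch.int32)

# Token count per expert across the whole (global) expert set.
global_expert_tokens = torch.bincount(topk_ids.reshape(-1).long(),
                                      minlength=global_num_experts).to(torch.int32)
print(global_expert_tokens)  # tensor([1, 2, 1, 1, 0, 2, 0, 1], dtype=torch.int32)
```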
```diff
@@ -549,53 +505,54 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,

     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device

     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        # Workaround: Convert to float so that sort runs on AI Core instead of slower AICPU
-        sorted_local_expert_idx, sorted_idx = torch.sort(
-            local_expert_idx.float())
-        sorted_local_expert_idx = sorted_local_expert_idx.to(
-            local_expert_idx.dtype)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
+
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
+
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
+
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
+
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
```
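The split-size handling above is the D2H optimization from the title: instead of two blocking `.cpu().tolist()` transfers, the received and locally computed per-expert counts are stacked, reduced to per-rank totals, and copied to host with one `non_blocking` device-to-host transfer whose two rows become the split-size lists for the token/scale `all_to_all_single` calls. A device-agnostic sketch of just that reduction (toy counts, hypothetical world size 2; on a real NPU stream the copy must be synchronized before the numpy array is read):

```python
import torch

world_size = 2
# Per-global-expert token counts: what arrived from the counts all-to-all
# (gather_sizes) and what this rank computed locally (global_expert_tokens).
gather_sizes = torch.tensor([3, 0, 2, 1, 4, 1, 0, 2], dtype=torch.int32)
global_expert_tokens = torch.tensor([1, 2, 0, 3, 2, 2, 1, 1], dtype=torch.int32)

# Stack both vectors so a single device-to-host copy serves both, then sum
# the experts owned by each rank to get per-rank row counts.
token_counts = torch.stack([gather_sizes, global_expert_tokens], dim=0)
token_counts = token_counts.view(2, world_size, -1).sum(dim=2)
token_counts_cpu = token_counts.to(torch.device("cpu"), non_blocking=True).numpy()

scatter_size_list = token_counts_cpu[0]  # from gather_sizes, as in the patch
gather_size_list = token_counts_cpu[1]   # from global_expert_tokens
print(scatter_size_list, gather_size_list)  # [6 7] [6 6]
```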
```diff
@@ -612,7 +569,8 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
         expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
             expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
-        group_list_type = 0
+        group_list_type = 0
+        dynamic_scale = None

     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
```
```diff
@@ -622,17 +580,21 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
         w2,
         w2_scale,
         expert_tokens, #16
+        dynamic_scale=dynamic_scale,
         group_list_type=group_list_type,
         w1_scale_bias=w1_scale_bias,
         w2_scale_bias=w2_scale_bias)

     if expert_map is not None:
-        # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
-        resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)

         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
```
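The `index_select` in this hunk restores the original token order after `npu_moe_re_routing`: argsorting a permutation yields its inverse, and the cast to float is only the AI Core scheduling workaround noted in the comment. A small CPU check of that identity (the float round-trip is kept for parity; it does not change the result):

```python
import torch

x = torch.arange(10, 20)                   # rows in their original order
perm = torch.randperm(10).to(torch.int32)  # stand-in for inverse_indices
shuffled = x[perm.long()]                  # rows after re-routing

# argsort of a permutation is its inverse permutation
inv = perm.to(torch.float32).argsort().to(torch.int32)
restored = torch.index_select(shuffled, dim=0, index=inv.long())
assert torch.equal(restored, x)
```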
```diff
@@ -641,8 +603,8 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,
+            drop_pad_mode=2)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
```
```diff
@@ -660,115 +622,6 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
     return final_hidden_states, expert_tokens, group_list_type


-def fused_experts_with_all2all_v2(
-        hidden_states: torch.Tensor,
-        top_k: int,
-        topk_ids: torch.Tensor,
-        topk_weight: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        expert_map: torch.Tensor = None,
-        ep_group: GroupCoordinator = None,
-        log2phy: Optional[torch.Tensor] = None,
-        global_redundant_expert_num: int = 0) -> torch.Tensor:
-    if log2phy is not None:
-        topk_ids = log2phy[topk_ids]
-
-    num_tokens, _ = hidden_states.shape
-    topk_weight = topk_weight.to(hidden_states.dtype)
-    global_num_experts = len(expert_map) + global_redundant_expert_num
-    local_num_experts = global_num_experts // ep_group.world_size
-
-    # Step 1: Routing initialization
-    row_idx = (torch.arange(0,
-                            num_tokens * top_k,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-
-    routed_tokens, expanded_row_idx, expert_indices = torch_npu.npu_moe_init_routing(
-        hidden_states,
-        row_idx=row_idx,
-        expert_idx=topk_ids,
-        active_num=num_tokens)
-
-    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
-        1, 0).contiguous().view(-1))
-
-    tokens_per_expert = torch.bincount(expert_indices,
-                                       minlength=global_num_experts).to(
-                                           torch.int32)
-
-    # Step 2: Quantize
-    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(routed_tokens)
-
-    # Step 3: All-to-All communication of token counts
-    tokens_per_expert_recv = torch.empty_like(tokens_per_expert)
-    dist.all_to_all_single(tokens_per_expert_recv, tokens_per_expert)
-
-    token_counts_combined = torch.stack(
-        [tokens_per_expert_recv, tokens_per_expert], dim=0)
-    token_counts_combined = token_counts_combined.view(2, ep_group.world_size,
-                                                       -1).sum(dim=2)
-
-    total_received_tokens = token_counts_combined[0].sum()
-    input_splits = token_counts_combined[1].cpu().tolist()
-    output_splits = token_counts_combined[0].cpu().tolist()
-
-    # Step 4: All-to-All token exchange
-    gathered_tokens = quantized_tokens.new_empty(total_received_tokens.item(),
-                                                 quantized_tokens.shape[1])
-    gathered_scales = token_scales.new_empty(total_received_tokens.item())
-
-    dist.all_to_all_single(gathered_tokens, quantized_tokens, output_splits,
-                           input_splits)
-    dist.all_to_all_single(gathered_scales, token_scales, output_splits,
-                           input_splits)
-
-    # Step 5: Re-routing received tokens
-    routed_tokens_by_expert, gathered_scales, inverse_indices, tokens_per_local_expert = torch_npu.npu_moe_re_routing(
-        gathered_tokens,
-        tokens_per_expert_recv.view(ep_group.world_size, -1),
-        per_token_scales=gathered_scales)
-
-    expert_outputs = gmm_expert(routed_tokens_by_expert,
-                                expert_tokens=tokens_per_local_expert.to(
-                                    torch.int64),
-                                w1=w1,
-                                w2=w2,
-                                w1_scale=w1_scale,
-                                w2_scale=w2_scale,
-                                dynamic_scale=gathered_scales,
-                                local_num_experts=local_num_experts)
-
-    # Step 6: Reorder outputs back to original token order
-    reordered_outputs = torch.index_select(
-        expert_outputs,
-        dim=0,
-        index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
-
-    # Step 7: Final All-to-All to return outputs to original ranks
-    final_gathered_tokens = reordered_outputs.new_empty(
-        *quantized_tokens.shape)
-    dist.all_to_all_single(final_gathered_tokens, reordered_outputs,
-                           input_splits, output_splits)
-
-    # Step 8: Final routing aggregation
-    final_hidden_states = torch_npu.npu_moe_finalize_routing(
-        final_gathered_tokens,
-        skip1=None,
-        skip2=None,
-        bias=None,
-        scales=topk_weight.to(final_gathered_tokens.dtype),
-        expanded_src_to_dst_row=expanded_row_idx,
-        export_for_source_row=None,
-        drop_pad_mode=2)
-
-    return final_hidden_states, tokens_per_local_expert.to(torch.int64), 1
-
-
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
```
```diff
@@ -977,7 +830,6 @@ def __init__(self):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout
-        self.enable_prefill_optimizations = ascend_config.enable_prefill_optimizations

         try:
             device_group = get_mc2_group().device_group
```
```diff
@@ -1159,21 +1011,6 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map)
-        elif self.enable_prefill_optimizations:
-            return fused_experts_with_all2all_v2(
-                hidden_states=x,
-                top_k=top_k,
-                topk_ids=topk_ids,
-                topk_weight=topk_weights,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                expert_map=expert_map,
-                ep_group=self.ep_group,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-            )
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
```
