@@ -110,73 +110,6 @@ def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
     return hidden_states
 
 
-def gmm_expert(hidden_states: torch.Tensor,
-               expert_tokens: List[int],
-               w1: torch.Tensor,
-               w2: torch.Tensor,
-               w1_scale: torch.Tensor,
-               w2_scale: torch.Tensor,
-               dynamic_scale: Optional[torch.Tensor] = None,
-               avg_tokens_per_expert: Optional[List[int]] = None,
-               local_num_experts: int = 16) -> torch.Tensor:
-    hidden_size = hidden_states.size(-1)
-
-    # Flatten input and dynamic_scale if needed
-    if dynamic_scale is not None and dynamic_scale.dim() > 1:
-        dynamic_scale = dynamic_scale.reshape(-1)
-    hidden_states = hidden_states.view(-1, hidden_size)
-
-    # First grouped matmul (up-projection)
-    mm1_output = torch_npu.npu_grouped_matmul(
-        [hidden_states], [w1],
-        group_list=expert_tokens,
-        split_item=3,
-        output_dtype=torch.int32,
-        group_type=0,
-        group_list_type=1,
-        tuning_config=avg_tokens_per_expert)[0]
-
-    # Prepare quantization scale for dequant + activation
-    quant_scale = torch.ones((local_num_experts, w1_scale.shape[-1] // 2),
-                             dtype=torch.float32,
-                             device=hidden_states.device)
-    w1_scale = w1_scale.to(torch.float32)
-
-    # Apply dequant + swiglu activation
-    intermediate_states, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-        x=mm1_output,
-        weight_scale=w1_scale,
-        activation_scale=dynamic_scale.squeeze(0)
-        if dynamic_scale is not None else None,
-        bias=None,
-        quant_scale=quant_scale,
-        quant_offset=None,
-        group_index=expert_tokens,
-        activate_left=True,
-        quant_mode=1)
-
-    # Flatten again if necessary
-    if dynamic_scale is not None and dynamic_scale.dim() > 1:
-        intermediate_states = intermediate_states.view(
-            -1, intermediate_states.size(-1))
-        dynamic_scale = dynamic_scale.reshape(-1)
-
-    # Final grouped matmul (down-projection)
-    output = torch_npu.npu_grouped_matmul(
-        [intermediate_states], [w2],
-        bias=None,
-        scale=[w2_scale.to(torch.bfloat16)],
-        per_token_scale=[dynamic_scale],
-        group_list=expert_tokens,
-        split_item=3,
-        output_dtype=torch.bfloat16,
-        group_type=0,
-        group_list_type=1,
-        tuning_config=avg_tokens_per_expert)[0]
-
-    return output
-
-
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
@@ -525,6 +458,29 @@ def fused_prefill_experts_with_mc2(
     return hidden_states_outputs, shared_outputs, expert_token_nums, group_list_type
 
 
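+# Fallback for torch_npu builds that do not provide the fused
+# npu_moe_init_routing_quant operator: perform MoE routing initialization and
+# per-token dynamic quantization with separate operators (see the hasattr check
+# in fused_experts_with_all2all below).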
+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(hidden_states: torch.Tensor,
@@ -549,53 +505,54 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
 
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device
 
     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        # Workaround: Convert to float so that sort runs on AI Core instead of slower AICPU
-        sorted_local_expert_idx, sorted_idx = torch.sort(
-            local_expert_idx.float())
-        sorted_local_expert_idx = sorted_local_expert_idx.to(
-            local_expert_idx.dtype)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
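+        # Use the fused routing-init + dynamic-quant operator when this
+        # torch_npu build provides it; otherwise fall back to the
+        # init_routing_quant helper defined above.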
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
+
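+        # Exchange per-expert token counts across EP ranks so each rank knows
+        # how many rows it will send to and receive from every peer.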
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
+
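+        # Size the receive buffers from the exchanged counts: this rank receives
+        # all_tokens rows in total, plus one dynamic scale per received row.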
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
+
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
+
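+        # Regroup the received tokens (and their per-token scales) by local
+        # expert before the expert MLP computation in apply_mlp.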
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
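+        # expert_tokens returned by npu_moe_re_routing holds per-expert token
+        # counts, so the group list is passed to apply_mlp as counts
+        # (group_list_type=1) rather than cumulative offsets.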
+        group_list_type = 1
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
@@ -612,7 +569,8 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
         expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
             expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
-        group_list_type = 0
+        group_list_type = 0
+        dynamic_scale = None
 
     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
@@ -622,17 +580,21 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
         w2,
         w2_scale,
         expert_tokens,  #16
+        dynamic_scale=dynamic_scale,
         group_list_type=group_list_type,
         w1_scale_bias=w1_scale_bias,
         w2_scale_bias=w2_scale_bias)
 
     if expert_map is not None:
-        # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
-        resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
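+        # Reverse all-to-all: send the expert outputs back to the ranks that
+        # own the original tokens, with the forward split sizes swapped.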
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)
 
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
@@ -641,8 +603,8 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,
+            drop_pad_mode=2)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
@@ -660,115 +622,6 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
     return final_hidden_states, expert_tokens, group_list_type
 
 
-def fused_experts_with_all2all_v2(
-        hidden_states: torch.Tensor,
-        top_k: int,
-        topk_ids: torch.Tensor,
-        topk_weight: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        expert_map: torch.Tensor = None,
-        ep_group: GroupCoordinator = None,
-        log2phy: Optional[torch.Tensor] = None,
-        global_redundant_expert_num: int = 0) -> torch.Tensor:
-    if log2phy is not None:
-        topk_ids = log2phy[topk_ids]
-
-    num_tokens, _ = hidden_states.shape
-    topk_weight = topk_weight.to(hidden_states.dtype)
-    global_num_experts = len(expert_map) + global_redundant_expert_num
-    local_num_experts = global_num_experts // ep_group.world_size
-
-    # Step 1: Routing initialization
-    row_idx = (torch.arange(0,
-                            num_tokens * top_k,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-
-    routed_tokens, expanded_row_idx, expert_indices = torch_npu.npu_moe_init_routing(
-        hidden_states,
-        row_idx=row_idx,
-        expert_idx=topk_ids,
-        active_num=num_tokens)
-
-    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
-        1, 0).contiguous().view(-1))
-
-    tokens_per_expert = torch.bincount(expert_indices,
-                                       minlength=global_num_experts).to(
-                                           torch.int32)
-
-    # Step 2: Quantize
-    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(routed_tokens)
-
-    # Step 3: All-to-All communication of token counts
-    tokens_per_expert_recv = torch.empty_like(tokens_per_expert)
-    dist.all_to_all_single(tokens_per_expert_recv, tokens_per_expert)
-
-    token_counts_combined = torch.stack(
-        [tokens_per_expert_recv, tokens_per_expert], dim=0)
-    token_counts_combined = token_counts_combined.view(2, ep_group.world_size,
-                                                       -1).sum(dim=2)
-
-    total_received_tokens = token_counts_combined[0].sum()
-    input_splits = token_counts_combined[1].cpu().tolist()
-    output_splits = token_counts_combined[0].cpu().tolist()
-
-    # Step 4: All-to-All token exchange
-    gathered_tokens = quantized_tokens.new_empty(total_received_tokens.item(),
-                                                 quantized_tokens.shape[1])
-    gathered_scales = token_scales.new_empty(total_received_tokens.item())
-
-    dist.all_to_all_single(gathered_tokens, quantized_tokens, output_splits,
-                           input_splits)
-    dist.all_to_all_single(gathered_scales, token_scales, output_splits,
-                           input_splits)
-
-    # Step 5: Re-routing received tokens
-    routed_tokens_by_expert, gathered_scales, inverse_indices, tokens_per_local_expert = torch_npu.npu_moe_re_routing(
-        gathered_tokens,
-        tokens_per_expert_recv.view(ep_group.world_size, -1),
-        per_token_scales=gathered_scales)
-
-    expert_outputs = gmm_expert(routed_tokens_by_expert,
-                                expert_tokens=tokens_per_local_expert.to(
-                                    torch.int64),
-                                w1=w1,
-                                w2=w2,
-                                w1_scale=w1_scale,
-                                w2_scale=w2_scale,
-                                dynamic_scale=gathered_scales,
-                                local_num_experts=local_num_experts)
-
-    # Step 6: Reorder outputs back to original token order
-    reordered_outputs = torch.index_select(
-        expert_outputs,
-        dim=0,
-        index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
-
-    # Step 7: Final All-to-All to return outputs to original ranks
-    final_gathered_tokens = reordered_outputs.new_empty(
-        *quantized_tokens.shape)
-    dist.all_to_all_single(final_gathered_tokens, reordered_outputs,
-                           input_splits, output_splits)
-
-    # Step 8: Final routing aggregation
-    final_hidden_states = torch_npu.npu_moe_finalize_routing(
-        final_gathered_tokens,
-        skip1=None,
-        skip2=None,
-        bias=None,
-        scales=topk_weight.to(final_gathered_tokens.dtype),
-        expanded_src_to_dst_row=expanded_row_idx,
-        export_for_source_row=None,
-        drop_pad_mode=2)
-
-    return final_hidden_states, tokens_per_local_expert.to(torch.int64), 1
-
-
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
@@ -977,7 +830,6 @@ def __init__(self):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout
-        self.enable_prefill_optimizations = ascend_config.enable_prefill_optimizations
 
         try:
             device_group = get_mc2_group().device_group
@@ -1159,21 +1011,6 @@ def apply(
                                 topk_ids=topk_ids,
                                 top_k=top_k,
                                 expert_map=expert_map)
-        elif self.enable_prefill_optimizations:
-            return fused_experts_with_all2all_v2(
-                hidden_states=x,
-                top_k=top_k,
-                topk_ids=topk_ids,
-                topk_weight=topk_weights,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                expert_map=expert_map,
-                ep_group=self.ep_group,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-            )
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.