@@ -334,6 +334,29 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output


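+# Fallback used when torch_npu does not ship the fused
+# npu_moe_init_routing_quant operator: run MoE token routing
+# (npu_moe_init_routing) and per-token dynamic int8 quantization
+# (npu_dynamic_quant) as two separate kernels.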
+def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+    num_tokens, _ = hidden_states.shape
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
+        1, 0).contiguous().view(-1))
+    global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                          minlength=global_num_experts)
+    global_expert_tokens = global_expert_tokens.to(torch.int32)
+    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
+    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
+
+
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(
@@ -358,50 +381,54 @@ def fused_experts_with_all2all(

     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device

     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
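+        # Prefer the fused routing + dynamic-quant kernel when the installed
+        # torch_npu provides it; otherwise fall back to the two-kernel
+        # init_routing_quant helper defined above.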
+        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
+            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
+                hidden_states,
+                expert_idx=topk_ids.to(torch.int32),
+                active_num=0,
+                expert_capacity=0,
+                expert_num=global_num_experts,
+                drop_pad_mode=0,
+                expert_tokens_num_mode=2,
+                expert_tokens_before_capacity_flag=False,
+                quant_mode=1,
+            )
+        else:
+            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
+                hidden_states, top_k, topk_ids, global_num_experts)
+
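+        # Exchange per-expert token counts across EP ranks so each rank
+        # learns how many tokens it will receive for its local experts.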
+        gather_sizes = global_expert_tokens.new_empty(
+            global_expert_tokens.shape[0])
+        dist.all_to_all_single(gather_sizes, global_expert_tokens)
+
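+        # Stack both directions of counts and reduce them per rank so a
+        # single device-to-host copy yields the scatter/gather size lists.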
+        token_counts_combined = torch.stack(
+            [gather_sizes, global_expert_tokens], dim=0)
+        token_counts_combined = token_counts_combined.view(
+            2, ep_group.world_size, -1).sum(dim=2)
+        token_counts_combined_cpu = token_counts_combined.to(
+            torch.device("cpu"), non_blocking=True).numpy()
+        all_tokens = gather_sizes.sum()
+
+        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
+                                                     quantized_tokens.shape[1])
+        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
+        gather_size_list = token_counts_combined_cpu[1]
+        scatter_size_list = token_counts_combined_cpu[0]
+
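+        # Dispatch the int8-quantized activations and their per-token scales
+        # to the ranks that own the target experts.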
+        dist.all_to_all_single(gathered_tokens, quantized_tokens,
+                               scatter_size_list, gather_size_list)
+        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
+                               gather_size_list)
+
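+        # Re-sort the received tokens by local expert id; inverse_indices is
+        # kept to restore the original order in the combine phase. With
+        # group_list_type=1, expert_tokens holds per-expert counts rather
+        # than the cumulative offsets used on the non-EP path below.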
+        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
+            gathered_tokens,
+            gather_sizes.view(ep_group.world_size, -1),
+            per_token_scales=dynamic_scale)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 1
     else:
         row_idx_len = num_tokens * top_k
         row_idx = torch.arange(0,
@@ -419,6 +446,7 @@ def fused_experts_with_all2all(
             expanded_expert_idx, num_experts)
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 0
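+        # No pre-computed scales on this path; apply_mlp quantizes
+        # internally when dynamic_scale is None.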
+        dynamic_scale = None

     # `hidden_states` will be disposed in the `apply_mlp` function
     hidden_states = apply_mlp(
@@ -428,14 +456,19 @@ def fused_experts_with_all2all(
         w2,
         w2_scale,
         expert_tokens,  #16
+        dynamic_scale=dynamic_scale,
         group_list_type=group_list_type)

     if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
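+        # Combine phase: undo the expert-major sort from npu_moe_re_routing,
+        # then all-to-all the expert outputs back to their source ranks.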
+        reordered_outputs = torch.index_select(
+            hidden_states,
+            dim=0,
+            # Workaround: convert to float so that argsort runs on AI Core
+            # instead of the slower AICPU.
+            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
+
+        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
+        dist.all_to_all_single(hidden_states, reordered_outputs,
+                               gather_size_list, scatter_size_list)

         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
@@ -444,8 +477,8 @@ def fused_experts_with_all2all(
             bias=None,
             scales=topk_weights,
             expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
+            export_for_source_row=None,
+            drop_pad_mode=2)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.