2020from typing import Any , Dict , List , Optional , Tuple , Union
2121
2222import torch
23- import tvm_ffi
23+ import paddle
24+
25+ with paddle .compat .use_torch_proxy_guard (enable = False ):
26+ import tvm_ffi
2427
2528from ..artifacts import ArtifactPath , MetaInfoHash
2629from ..autotuner import (
@@ -463,11 +466,15 @@ def __init__(
463466 use_mxfp8_act_scaling ,
464467 )
465468
469+ def paddle_dtype_to_tvm_ffi_dtype (dtype : paddle .dtype ):
470+ dtype_str = str (dtype ).split ("." , 1 )[- 1 ]
471+ return tvm_ffi .dtype (dtype_str )
472+
466473 if instance_key not in MoERunner .runner_dict :
467474 MoERunner .runner_dict [instance_key ] = module .init (
468- x_dtype ,
469- weight_dtype ,
470- output_dtype ,
475+ paddle_dtype_to_tvm_ffi_dtype ( x_dtype ) ,
476+ paddle_dtype_to_tvm_ffi_dtype ( weight_dtype ) ,
477+ paddle_dtype_to_tvm_ffi_dtype ( output_dtype ) ,
471478 use_deepseek_fp8_block_scale ,
472479 use_w4_group_scaling ,
473480 use_mxfp8_act_scaling ,
@@ -565,7 +572,8 @@ def cutlass_fused_moe(
565572 enable_pdl : Optional [bool ] = None ,
566573 ) -> List [torch .Tensor ]:
567574 if enable_pdl is None :
568- enable_pdl = device_support_pdl (input .device )
575+ # enable_pdl = device_support_pdl(input.device)
576+ enable_pdl = device_support_pdl (input .place )
569577 tuner = AutoTuner .get ()
570578 MoERunner .refine_tuning_config (tune_max_num_tokens )
571579
@@ -623,17 +631,22 @@ def cutlass_fused_moe(
623631 else moe_runner .fused_moe_runner .run_moe
624632 )
625633 num_active_experts_per_node = torch .empty (
626- (1 ,), dtype = torch .int32 , device = input .device
634+ # (1,), dtype=torch.int32, device=input.device
635+ (1 ,),
636+ dtype = torch .int32 ,
637+ device = input .place ,
627638 )
628639 experts_to_token_score = torch .empty (
629640 (fc2_expert_weights .shape [0 ], input .shape [0 ]),
630641 dtype = torch .float32 ,
631- device = input .device ,
642+ # device=input.device,
643+ device = input .place ,
632644 )
633645 active_expert_global_ids = torch .empty (
634646 (fc2_expert_weights .shape [0 ],),
635647 dtype = torch .int32 ,
636- device = input .device ,
648+ # device=input.device,
649+ device = input .place ,
637650 )
638651 min_latency_output = (
639652 [
@@ -897,7 +910,8 @@ def cutlass_fused_moe(
897910 raise NotImplementedError ("min latency mode not yet implemented for Blackwell." )
898911
899912 if enable_pdl is None :
900- enable_pdl = device_support_pdl (input .device )
913+ # enable_pdl = device_support_pdl(input.device)
914+ enable_pdl = device_support_pdl (input .place )
901915
902916 num_rows = input .shape [0 ]
903917 if min_latency_mode :
@@ -906,10 +920,16 @@ def cutlass_fused_moe(
906920 output_shape = (num_rows , hidden_size )
907921
908922 if output is None :
909- output = torch .empty (output_shape , dtype = output_dtype , device = input .device )
923+ # output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
924+ output = torch .empty (output_shape , dtype = output_dtype , device = input .place )
910925 else :
911926 check_shape_dtype_device (
912- output , output_shape , output_dtype , input .device , "output"
927+ # output, output_shape, output_dtype, input.device, "output"
928+ output ,
929+ output_shape ,
930+ output_dtype ,
931+ input .place ,
932+ "output" ,
913933 )
914934
915935 major , minor = torch .cuda .get_device_capability ()
# --- end of captured diff (scrape artifact "0 commit comments" removed) ---