1111 dispatch_fused_experts_func , dispatch_topk_func ,
1212 torch_vllm_inplace_fused_experts , torch_vllm_outplace_fused_experts ,
1313 vllm_topk_softmax )
14+ from vllm .model_executor .layers .fused_moe .rocm_aiter_fused_moe import (
15+ is_rocm_aiter_moe_enabled )
1416from vllm .model_executor .layers .layernorm import (
1517 RMSNorm , dispatch_cuda_rmsnorm_func , fused_add_rms_norm , rms_norm ,
1618 rocm_aiter_fused_add_rms_norm , rocm_aiter_rms_norm )
@@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str):
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
    """Verify dispatch_topk_func() honors the VLLM_ROCM_USE_AITER env flag.

    On ROCm with AITER enabled the AITER topk-softmax implementation must be
    dispatched; in every other configuration the default vLLM implementation
    is returned.
    """
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    # is_rocm_aiter_moe_enabled() is cached (lru_cache); it must be cleared
    # BEFORE calling dispatch_topk_func(), otherwise the dispatch decision is
    # made from a stale cached value left over from a previous parametrized
    # run instead of the env var set above. (The sibling
    # test_fused_experts_dispatch already uses this clear-then-dispatch
    # order.)
    is_rocm_aiter_moe_enabled.cache_clear()
    topk_func = dispatch_topk_func()
    if current_platform.is_rocm() and int(use_rocm_aiter):
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_topk_softmax)
        assert topk_func == rocm_aiter_topk_softmax
    else:
        assert topk_func == vllm_topk_softmax
@@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
116117 monkeypatch ):
117118
118119 monkeypatch .setenv ("VLLM_ROCM_USE_AITER" , use_rocm_aiter )
120+ is_rocm_aiter_moe_enabled .cache_clear ()
119121 fused_experts_func = dispatch_fused_experts_func (inplace )
120122 if current_platform .is_rocm () and int (use_rocm_aiter ):
121123 from vllm .model_executor .layers .fused_moe .rocm_aiter_fused_moe import (
122124 rocm_aiter_fused_experts )
123-
124125 assert fused_experts_func == rocm_aiter_fused_experts
125126 elif inplace :
126127 assert fused_experts_func == torch_vllm_inplace_fused_experts
0 commit comments