@@ -15,9 +15,11 @@
 from .base_device_communicator import All2AllManagerBase, Cache

 if has_flashinfer_all2all():
-    from flashinfer.comm import Mapping
-    from flashinfer.comm.mnnvl import MnnvlConfig
-    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
+    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
+    from flashinfer.comm.trtllm_alltoall import (
+        MnnvlMoe,  # type: ignore[import-not-found]
+    )

 logger = init_logger(__name__)

@@ -65,6 +67,7 @@ def dispatch(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

         hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ def combine(
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank

         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

@@ -113,7 +117,10 @@ def dispatch(
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ def combine(
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ def __init__(self, cpu_group):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (
+            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                 nvshmem_alloc_empty_unique_id,
                 nvshmem_get_unique_id,
                 nvshmem_init,
@@ -182,7 +192,7 @@ def __init__(self, cpu_group):
         self.handle_cache = Cache()

     def get_handle(self, kwargs):
-        import pplx_kernels as pplx
+        import pplx_kernels as pplx  # type: ignore[import-not-found]

         return self.handle_cache.get_or_create(
             kwargs,
@@ -208,7 +218,9 @@ def destroy(self):
             handle.destroy()

         if self.internode:
-            from pplx_kernels.nvshmem import nvshmem_finalize
+            from pplx_kernels.nvshmem import (
+                nvshmem_finalize,  # type: ignore[import-not-found]
+            )

             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
@@ -288,7 +300,7 @@ def get_handle(self, kwargs):
             "args are computed in the Manager itself."
         )

-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ def get_handle(self, kwargs):
         return handle

     def set_num_sms(self, num_sms: int):
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Right now the buffers are sized for only what the kernels were
         # created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ def _make_all2all_kwargs(
             num_global_experts: Number of experts in the model.
             num_local_experts: Number of experts in an EP rank.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         # Defaults for internode and intranode are taken from DeepEP tests.
         num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ def get_handle(self, kwargs):
         The kwargs for DeepEPLLAll2AllManager is dictated by
         _make_all2all_kwargs.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]

         buffer_kwargs = self._make_all2all_kwargs(**kwargs)
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
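For context, the two typing patterns this diff applies are sketched below in a minimal, standalone example (not part of the diff): narrowing an Optional value with an assert before it is used, and silencing mypy's import-not-found error on imports of optional dependencies. DPMetadata, get_dp_metadata, and load_optional_backend are hypothetical stand-ins, not vLLM APIs; only deep_ep corresponds to a real optional dependency referenced above.

# Standalone sketch of the typing patterns used in the change above.
from types import ModuleType
from typing import Optional


class DPMetadata:
    # Hypothetical stand-in for the dp_metadata object on the forward context.
    def cu_tokens_across_sp(self, sp_size: int) -> list[int]:
        return [0, 4 * sp_size, 8 * sp_size]


def get_dp_metadata() -> Optional[DPMetadata]:
    # The real forward context may return None here, hence the Optional type.
    return DPMetadata()


def dispatch(sp_size: int) -> list[int]:
    dp_metadata = get_dp_metadata()
    # Pattern 1: the assert narrows Optional[DPMetadata] to DPMetadata, so a
    # type checker accepts the method call on the next line.
    assert dp_metadata is not None
    return dp_metadata.cu_tokens_across_sp(sp_size)


def load_optional_backend() -> Optional[ModuleType]:
    try:
        # Pattern 2: for an optional dependency without type stubs, the inline
        # ignore keeps mypy from reporting import-not-found when it is absent.
        import deep_ep  # type: ignore[import-not-found]
    except ImportError:
        return None
    return deep_ep


if __name__ == "__main__":
    print(dispatch(sp_size=2))
    print("deep_ep available:", load_optional_backend() is not None)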