|
57 | 57 | from vllm.platforms.interface import CpuArchEnum |
58 | 58 | from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up |
59 | 59 | from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe |
60 | | -from vllm.utils.torch_utils import direct_register_custom_op |
| 60 | +from vllm.utils.torch_utils import current_stream, direct_register_custom_op |
61 | 61 | from vllm.v1.worker.ubatching import dbo_current_ubatch_id |
62 | 62 |
|
63 | 63 | if current_platform.is_cuda_alike(): |
@@ -1082,6 +1082,17 @@ def __init__( |
1082 | 1082 | n_shared_experts: int | None = None, |
1083 | 1083 | ): |
1084 | 1084 | super().__init__() |
| 1085 | + |
| 1086 | + # Allow disabling the separate shared experts stream for |
| 1087 | + # debugging purposes. |
| 1088 | + # TODO: Remove this after more extensive testing with TP/DP |
| 1089 | + # and other execution modes |
| 1090 | + if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM: |
| 1091 | + logger.info_once("Disabling MoE shared_experts cuda stream") |
| 1092 | + self.shared_experts_stream = None |
| 1093 | + else: |
| 1094 | + self.shared_experts_stream = torch.cuda.Stream() |
| 1095 | + |
1085 | 1096 | if params_dtype is None: |
1086 | 1097 | params_dtype = torch.get_default_dtype() |
1087 | 1098 | self.params_dtype = params_dtype |
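
Note on the constructor change above: it boils down to "create one dedicated CUDA stream per layer unless the debug flag is set". The sketch below restates that pattern outside the vLLM class; it assumes only that `VLLM_DISABLE_SHARED_EXPERTS_STREAM` is read from the process environment (the class name and the direct `os.environ` lookup are illustrative; the real code goes through `vllm.envs`).

```python
import os

import torch


class MoELayerSketch(torch.nn.Module):
    """Illustrative only: mirrors the optional shared-experts stream setup."""

    def __init__(self) -> None:
        super().__init__()
        # Debug escape hatch: export VLLM_DISABLE_SHARED_EXPERTS_STREAM=1 to
        # fall back to single-stream execution.
        if os.environ.get("VLLM_DISABLE_SHARED_EXPERTS_STREAM") == "1":
            self.shared_experts_stream = None
        elif torch.cuda.is_available():
            # One extra stream per layer, reused across forward passes.
            self.shared_experts_stream = torch.cuda.Stream()
        else:
            self.shared_experts_stream = None
```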
@@ -1332,6 +1343,10 @@ def __init__( |
1332 | 1343 | def shared_experts(self) -> torch.nn.Module | None: |
1333 | 1344 | return None |
1334 | 1345 |
|
| 1346 | + @property |
| 1347 | + def gate(self) -> torch.nn.Module | None: |
| 1348 | + return None |
| 1349 | + |
1335 | 1350 | @property |
1336 | 1351 | def tp_size(self): |
1337 | 1352 | return self.moe_parallel_config.tp_size |
@@ -1390,6 +1405,11 @@ def use_dp_chunking(self) -> bool: |
1390 | 1405 | or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) |
1391 | 1406 | ) |
1392 | 1407 |
|
| 1408 | + @property |
| 1409 | + def is_internal_router(self) -> bool: |
| 1410 | + # By default, the router/gate is called before the FusedMoE forward pass |
| 1411 | + return False |
| 1412 | + |
1393 | 1413 | def update_expert_map(self): |
1394 | 1414 | # ep_size and ep_rank should already be updated |
1395 | 1415 | assert self.expert_map is not None |
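
The two hooks added above (`gate` and `is_internal_router`) default to "the caller applies the router before forward". A subclass that wants the overlapped path would override them roughly as in the hypothetical sketch below; the class name and the `_gate` attribute are illustrative, not part of the patch.

```python
import torch


class GatedFusedMoESketch(FusedMoE):  # assumes FusedMoE from this module
    """Hypothetical subclass showing how the new hooks are meant to be used."""

    def __init__(self, *args, gate: torch.nn.Module, **kwargs):
        super().__init__(*args, **kwargs)
        self._gate = gate

    @property
    def gate(self) -> torch.nn.Module | None:
        # Expose the router so forward_impl can compute router_logits itself,
        # after the shared experts stream has already been started.
        return self._gate

    @property
    def is_internal_router(self) -> bool:
        # Tell callers not to apply the gate before calling forward.
        return True
```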
@@ -2168,6 +2188,7 @@ def forward_impl_chunked( |
2168 | 2188 | self, |
2169 | 2189 | full_hidden_states: torch.Tensor, |
2170 | 2190 | full_router_logits: torch.Tensor, |
| 2191 | + has_separate_shared_experts: bool, |
2171 | 2192 | ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: |
2172 | 2193 | assert self.batched_hidden_states is not None |
2173 | 2194 | assert self.batched_router_logits is not None |
@@ -2216,11 +2237,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): |
2216 | 2237 |
|
2217 | 2238 | # If there are shared experts but we are not using a modular kernel, |
2218 | 2239 | # the shared experts must be called here |
2219 | | - if ( |
2220 | | - not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) |
2221 | | - and self.shared_experts is not None |
2222 | | - ): |
2223 | | - shared_output = self.shared_experts(staged_hidden_states) |
| 2240 | + if has_separate_shared_experts: |
| 2241 | + assert self.shared_experts is not None |
| 2242 | + |
| 2243 | + if self.shared_experts_stream is not None: |
| 2244 | + # For the chunked path, we start the shared experts stream here |
| 2245 | + # (note that there is no concurrency with the router/gate) |
| 2246 | + self.shared_experts_stream.wait_stream(current_stream()) |
| 2247 | + |
| 2248 | + with torch.cuda.stream(self.shared_experts_stream): |
| 2249 | + # Note that cloning staged_hidden_states is necessary |
| 2250 | + # here to avoid a conflict with the main stream |
| 2251 | + shared_output = self.shared_experts( |
| 2252 | + staged_hidden_states.clone() |
| 2253 | + ) |
| 2254 | + else: |
| 2255 | + shared_output = self.shared_experts(staged_hidden_states) |
| 2256 | + |
2224 | 2257 | else: |
2225 | 2258 | shared_output = None |
2226 | 2259 |
|
@@ -2249,9 +2282,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): |
2249 | 2282 | logical_replica_count=self.logical_replica_count, |
2250 | 2283 | ) |
2251 | 2284 |
|
2252 | | - if shared_output is not None: |
| 2285 | + if has_separate_shared_experts: |
2253 | 2286 | assert not isinstance(final_hidden_states, tuple) |
2254 | 2287 | assert self.shared_experts is not None |
| 2288 | + |
| 2289 | + # Here we finish the shared experts stream |
| 2290 | + if self.shared_experts_stream is not None: |
| 2291 | + current_stream().wait_stream(self.shared_experts_stream) |
| 2292 | + |
2255 | 2293 | final_hidden_states = ( |
2256 | 2294 | shared_output, |
2257 | 2295 | final_hidden_states, |
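
Taken together, the two chunked-path hunks form a per-chunk fork/join: the side stream waits on the main stream, the shared experts run on a clone of the staged input, and the main stream waits on the side stream before the (shared, routed) tuple is assembled. There is nothing to overlap with the gate here, since the router logits for the chunk already exist. A generic sketch of the pattern, independent of the vLLM classes and with assumed function names:

```python
import torch


def fork_join_shared_experts(x, routed_experts, shared_experts, side_stream):
    """Generic fork/join sketch; not the actual vLLM method."""
    main = torch.cuda.current_stream()
    if side_stream is None:
        # Debug / fallback path: everything runs on the main stream.
        return shared_experts(x), routed_experts(x)

    # Fork: the side stream must first see all prior main-stream work on x.
    side_stream.wait_stream(main)
    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own buffer, so later main-stream
        # ops on x cannot conflict with the shared-experts reads.
        shared_out = shared_experts(x.clone())

    routed_out = routed_experts(x)  # concurrent work on the main stream

    # Join: the main stream must not consume shared_out until the side
    # stream has finished producing it.
    main.wait_stream(side_stream)
    return shared_out, routed_out
```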
@@ -2321,20 +2359,51 @@ def forward_impl( |
2321 | 2359 |
|
2322 | 2360 | self.ensure_moe_quant_config() |
2323 | 2361 |
|
2324 | | - if self.use_dp_chunking: |
2325 | | - return self.forward_impl_chunked(hidden_states, router_logits) |
| 2362 | + has_separate_shared_experts = ( |
| 2363 | + not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) |
| 2364 | + and self.shared_experts is not None |
| 2365 | + ) |
| 2366 | + |
| 2367 | + use_chunked_impl = self.use_dp_chunking |
| 2368 | + |
| 2369 | + if ( |
| 2370 | + has_separate_shared_experts |
| 2371 | + and not use_chunked_impl |
| 2372 | + and self.shared_experts_stream is not None |
| 2373 | + ): |
| 2374 | + # Start the separate shared experts stream here since we want |
| 2375 | + # it to run in parallel with the router/gate (the next op below) |
| 2376 | + self.shared_experts_stream.wait_stream(current_stream()) |
| 2377 | + |
| 2378 | + # If router/gate provided, then apply it here. |
| 2379 | + # (Note: this code runs only when "overlapped mode" is on, to allow |
| 2380 | + # parallel execution of the shared experts with the FusedMoE via a |
| 2381 | + # separate CUDA stream.) |
| 2382 | + if self.gate is not None: |
| 2383 | + router_logits, _ = self.gate(hidden_states) |
| 2384 | + |
| 2385 | + if use_chunked_impl: |
| 2386 | + return self.forward_impl_chunked( |
| 2387 | + hidden_states, router_logits, has_separate_shared_experts |
| 2388 | + ) |
2326 | 2389 |
|
2327 | 2390 | do_naive_dispatch_combine: bool = ( |
2328 | 2391 | self.dp_size > 1 and not self.quant_method.using_modular_kernel |
2329 | 2392 | ) |
2330 | 2393 |
|
2331 | 2394 | # If there are shared experts but we are not using a modular kernel, the |
2332 | 2395 | # shared experts must be called here |
2333 | | - if ( |
2334 | | - not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) |
2335 | | - and self.shared_experts is not None |
2336 | | - ): |
2337 | | - shared_output = self.shared_experts(hidden_states) |
| 2396 | + if has_separate_shared_experts: |
| 2397 | + assert self.shared_experts is not None |
| 2398 | + |
| 2399 | + if self.shared_experts_stream is not None: |
| 2400 | + # Run shared experts in parallel on a separate stream |
| 2401 | + with torch.cuda.stream(self.shared_experts_stream): |
| 2402 | + # Note that cloning hidden_states is necessary here to avoid |
| 2403 | + # a conflict with the main stream |
| 2404 | + shared_output = self.shared_experts(hidden_states.clone()) |
| 2405 | + else: |
| 2406 | + shared_output = self.shared_experts(hidden_states) |
2338 | 2407 | else: |
2339 | 2408 | shared_output = None |
2340 | 2409 |
|
@@ -2377,9 +2446,14 @@ def forward_impl( |
2377 | 2446 | logical_replica_count=self.logical_replica_count, |
2378 | 2447 | ) |
2379 | 2448 |
|
2380 | | - if shared_output is not None: |
| 2449 | + if has_separate_shared_experts: |
2381 | 2450 | assert not isinstance(final_hidden_states, tuple) |
2382 | 2451 | assert self.shared_experts is not None |
| 2452 | + |
| 2453 | + # Wait for the parallel shared experts stream to finish here |
| 2454 | + if self.shared_experts_stream is not None: |
| 2455 | + current_stream().wait_stream(self.shared_experts_stream) |
| 2456 | + |
2383 | 2457 | final_hidden_states = ( |
2384 | 2458 | shared_output, |
2385 | 2459 | final_hidden_states, |
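
The non-chunked path differs from the chunked one only in where the fork happens: the side stream is started before the router/gate, so the shared experts overlap with the gate GEMM as well as with the routed experts. A condensed ordering sketch (assumed names, illustrative only):

```python
import torch


def overlapped_forward(hidden_states, gate, shared_experts, routed_experts,
                       side_stream):
    """Ordering sketch for the non-chunked path; names are assumed."""
    main = torch.cuda.current_stream()

    # 1. Fork before the gate so the shared experts overlap with it too.
    side_stream.wait_stream(main)
    with torch.cuda.stream(side_stream):
        shared_out = shared_experts(hidden_states.clone())

    # 2. Router/gate on the main stream (only when the layer owns its gate).
    router_logits, _ = gate(hidden_states)

    # 3. Routed experts on the main stream, concurrent with the side stream.
    routed_out = routed_experts(hidden_states, router_logits)

    # 4. Join before returning the (shared, routed) pair.
    main.wait_stream(side_stream)
    return shared_out, routed_out
```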
|