@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1079,6 +1080,20 @@ def __init__(
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # TODO: Allow disabling of the separate shared experts stream for
+        # debug purposes. Remove this after more extensive testing with
+        # TP/DP and other execution modes.
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
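
The constructor change above gates the side stream on an environment variable. A minimal sketch of that toggle outside of vLLM follows; the helper name is illustrative and not part of the diff, and note that any value (even an empty string) disables the stream, since only an unset variable yields None.

import os

import torch

def make_shared_experts_stream() -> torch.cuda.Stream | None:
    # Mirrors the check in __init__: only an unset variable enables the stream.
    if os.environ.get("DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM") is not None:
        return None  # shared experts will run inline on the main stream
    return torch.cuda.Stream()  # requires a CUDA device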
@@ -1328,6 +1343,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
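
The new gate property defaults to None, so a subclass has to override it (together with shared_experts) for the overlapped path in forward_impl below to activate. A hypothetical subclass sketch, with illustrative class and attribute names only:

class OverlappedSharedMoE(FusedMoE):
    """Exposes its shared experts and gate so the base FusedMoE can run the
    shared experts on the side stream and apply the gate itself."""

    def __init__(self, shared_experts: torch.nn.Module,
                 gate: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        self._gate = gate

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared_experts

    @property
    def gate(self) -> torch.nn.Module | None:
        return self._gate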
@@ -2150,6 +2169,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2198,11 +2218,24 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For the chunked path, start the shared experts stream
+                    # here (no concurrency with the router/gate in this case).
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # The staged_hidden_states clone() is necessary here
+                        # to avoid a conflict with the main stream.
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None
 
@@ -2231,9 +2264,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
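
Taken together, the two chunked-path hunks implement a standard fork/join pattern on CUDA streams: the side stream first waits on the main stream, runs the shared experts on a cloned input, and the main stream waits on the side stream before packaging the outputs. A self-contained sketch of that pattern, with stand-in modules rather than vLLM APIs (requires a CUDA device):

import torch

def overlapped_shared_experts(hidden_states, shared_experts, fused_experts,
                              side_stream: torch.cuda.Stream):
    main_stream = torch.cuda.current_stream()
    # Fork: the side stream waits for all work already queued on the main
    # stream (e.g. the production of hidden_states).
    side_stream.wait_stream(main_stream)
    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own copy, so later writes to
        # hidden_states on the main stream cannot race with this read.
        shared_out = shared_experts(hidden_states.clone())
    # The routed experts run concurrently on the main stream.
    routed_out = fused_experts(hidden_states)
    # Join: the main stream must not read shared_out until the side stream
    # has finished producing it.
    main_stream.wait_stream(side_stream)
    return shared_out, routed_out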
@@ -2303,20 +2341,52 @@ def forward_impl(
 
         self.ensure_moe_quant_config()
 
-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here, since we want it
+            # to run in parallel with the router/gate (the next op below).
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate is provided, apply it here.
+        # (Note: this code runs only when "overlapped mode" is on, allowing
+        # the shared experts to run in parallel with the FusedMoE via a
+        # separate cuda stream.)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run shared experts in parallel on a separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that the hidden_states clone() is necessary here to
+                    # avoid a conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
 
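
Unlike the chunked path, the non-chunked path forks the side stream before the gate GEMM, so the shared experts overlap with both the router and the fused experts. The guard for that early fork reduces to the following pure-Python sketch of the logic above (no vLLM imports; parameter names are illustrative):

def forks_before_gate(is_modular_kernel: bool, has_shared_experts: bool,
                      use_dp_chunking: bool, stream_disabled: bool) -> bool:
    # The fork that overlaps the router/gate happens only when the shared
    # experts run outside a modular kernel, chunking is off, and the side
    # stream was not disabled via the environment variable.
    has_separate_shared_experts = (not is_modular_kernel) and has_shared_experts
    return (has_separate_shared_experts
            and not use_dp_chunking
            and not stream_disabled)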
@@ -2359,9 +2429,14 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
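
When the layer handles the shared experts itself, forward_impl returns a (shared_output, final_hidden_states) tuple after the join. A hypothetical consumer-side sketch follows; the exact combination (and any scaling of the routed output) is model-specific and assumed here:

def combine_moe_outputs(out):
    # Unpack the (shared_output, routed_output) tuple and combine by simple
    # addition; real models may apply a routed scaling factor first.
    if isinstance(out, tuple):
        shared_output, routed_output = out
        return shared_output + routed_output
    return out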