 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1042,6 +1043,20 @@ def __init__(
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
     ):
         super().__init__()
+
+        # TODO: This allows disabling the separate shared experts stream for
+        # debug purposes. Remove it after more extensive testing with TP/DP
+        # and other execution modes.
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
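A minimal usage sketch of the new flag (illustrative only, not part of this diff): because the code above checks only whether the variable is set, any value, including "0", disables the separate stream, and it must be set before the layer's __init__ runs.

    import os

    # Presence alone disables the stream; set this before model construction.
    os.environ["DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM"] = "1"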
@@ -1265,6 +1280,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
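The new `gate` hook mirrors the existing `shared_experts` one; both return None in the base class. As a purely hypothetical sketch (class and attribute names are invented here, not taken from this diff), a derived MoE layer could override the two properties to opt into the overlapped path in forward_impl:

    import torch

    class OverlappedSharedMoE(FusedMoE):  # hypothetical subclass for illustration
        def __init__(self, gate: torch.nn.Module, shared: torch.nn.Module, **kwargs):
            super().__init__(**kwargs)
            self._gate = gate
            self._shared = shared

        @property
        def gate(self) -> torch.nn.Module | None:
            # Lets forward_impl compute the router logits itself, so the shared
            # experts can overlap with the gate on the separate CUDA stream.
            return self._gate

        @property
        def shared_experts(self) -> torch.nn.Module | None:
            return self._shared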
@@ -2054,6 +2073,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2102,11 +2122,24 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For the chunked path, start the shared experts stream here
+                    # (note that there is no concurrency with the router/gate).
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # A clone() of staged_hidden_states is necessary here
+                        # to avoid a conflict with the main stream.
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None
@@ -2133,9 +2166,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Wait here for the shared experts stream to finish.
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
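For reference, the per-chunk synchronization added above is a standard two-stream pattern: the side stream first waits on the main stream, runs on a cloned input, and the main stream waits on the side stream before the result is consumed. A standalone sketch with illustrative names (not code from this file):

    import torch

    def run_on_side_stream(fn, x: torch.Tensor, side_stream: torch.cuda.Stream):
        main = torch.cuda.current_stream()
        # The side stream must not start before x has been fully produced.
        side_stream.wait_stream(main)
        with torch.cuda.stream(side_stream):
            # clone() gives the side stream its own buffer, so later work on
            # the main stream cannot overwrite or reuse memory fn still reads.
            out = fn(x.clone())
        # ... main-stream work (e.g. the routed experts) would run here ...
        main.wait_stream(side_stream)  # join before 'out' is consumed
        return out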
@@ -2205,20 +2243,52 @@ def forward_impl(
 
         self.ensure_moe_quant_config()
 
-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here since we want it
+            # to run in parallel with the router/gate (the next op below).
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate is provided, apply it here.
+        # (Note: this happens only in "overlapped mode", which allows the
+        # shared experts to run in parallel with the FusedMoE via the
+        # separate cuda stream.)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )
 
         # If there are shared experts but we are not using a modular kernel,
         # the shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run the shared experts in parallel on the separate stream
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # A clone() of hidden_states is necessary here to avoid
+                    # a conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
 
@@ -2259,9 +2329,14 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here.
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
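In the non-chunked path the wait on the side stream is recorded before the gate call, so the GPU can overlap the shared experts with both the gate and the routed experts running on the main stream. A sketch of that ordering, with invented helper names (not the layer's actual API):

    import torch

    def overlapped_moe_forward(gate, shared_experts, routed_experts, x, side_stream):
        main = torch.cuda.current_stream()
        side_stream.wait_stream(main)        # recorded before the gate runs
        router_logits, _ = gate(x)           # main stream
        with torch.cuda.stream(side_stream):
            shared_out = shared_experts(x.clone())      # overlaps with main-stream work
        routed_out = routed_experts(x, router_logits)   # main stream
        main.wait_stream(side_stream)        # join before combining the outputs
        return shared_out, routed_out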