@@ -1042,6 +1042,9 @@ def __init__(
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
     ):
         super().__init__()
+
+        self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1265,6 +1268,10 @@ def __init__(
     def shared_experts(self) -> torch.nn.Module | None:
         return None
 
+    @property
+    def gate(self) -> Optional[torch.nn.Module]:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
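
The new `gate` property defaults to None. When a subclass returns its router module here, `forward_impl` (below) computes the logits itself via `router_logits, _ = self.gate(hidden_states)`, after the shared experts stream has already been kicked off. As an illustration only, a subclass could wire this up roughly as in the sketch below; the class name, constructor arguments, and the assumed import path for the FusedMoE layer being modified in this diff are my own assumptions, not part of the change. The returned gate must yield a `(logits, extra)` tuple, since that is how `forward_impl` unpacks it.

import torch
from vllm.model_executor.layers.fused_moe import FusedMoE  # assumed import path


class MyFusedMoE(FusedMoE):  # hypothetical subclass, not from this PR
    def __init__(self, *args, gate_layer: torch.nn.Module,
                 shared: torch.nn.Module, **kwargs):
        super().__init__(*args, **kwargs)
        self._gate = gate_layer    # e.g. a linear router returning (logits, bias)
        self._shared = shared      # the shared experts MLP

    @property
    def gate(self) -> torch.nn.Module | None:
        # Exposing the router lets forward_impl overlap it with the
        # shared experts running on the side stream.
        return self._gate

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared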
@@ -2054,6 +2061,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2102,11 +2110,19 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                # For the chunked path, we start the shared experts stream here
+                # (note that there is no concurrency with the router/gate).
+                current_stream = torch.cuda.current_stream()
+                self.shared_experts_stream.wait_stream(current_stream)
+
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # The staged_hidden_states clone() is necessary here to
+                    # avoid a conflict with the main stream.
+                    shared_output = self.shared_experts(staged_hidden_states.clone())
+
             else:
                 shared_output = None
 
@@ -2133,9 +2149,13 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
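
For reference, the stream handshake used in the chunked path above reduces to the standalone sketch below (not vLLM code; `overlapped_forward`, `shared_experts`, and `routed_experts` are stand-in names): the side stream first waits on the main stream, runs on a cloned copy of the input, and the main stream waits on the side stream before it touches the result.

import torch

def overlapped_forward(x, shared_experts, routed_experts, side_stream):
    main = torch.cuda.current_stream()

    # The side stream must see all prior work on x before reading it.
    side_stream.wait_stream(main)
    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own copy, so allocator reuse or
        # in-place updates of x on the main stream cannot race with this read.
        shared_out = shared_experts(x.clone())

    # The routed experts run concurrently on the main stream.
    routed_out = routed_experts(x)

    # The main stream must not consume shared_out until the side stream is done.
    main.wait_stream(side_stream)
    return shared_out + routed_out

if torch.cuda.is_available():
    x = torch.randn(8, 64, device="cuda")
    shared = torch.nn.Linear(64, 64, device="cuda")
    routed = torch.nn.Linear(64, 64, device="cuda")
    out = overlapped_forward(x, shared, routed, torch.cuda.Stream())
    torch.cuda.synchronize()

The `clone()` plus the two `wait_stream` calls are what make the overlap safe without a full device synchronization.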
@@ -2205,20 +2225,42 @@ def forward_impl(
 
         self.ensure_moe_quant_config()
 
-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if has_separate_shared_experts and not use_chunked_impl:
+            # Start the separate shared experts stream here, since we want it
+            # to run in parallel with the router/gate (the next op below).
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate was provided, apply it here.
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            # Run the shared experts in parallel on a separate stream.
+            with torch.cuda.stream(self.shared_experts_stream):
+                # The hidden_states clone() is necessary here to avoid a
+                # conflict with the main stream.
+                shared_output = self.shared_experts(hidden_states.clone())
         else:
             shared_output = None
 
@@ -2259,9 +2301,13 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here.
+            current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
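
Putting the non-chunked pieces together, the ordering in `forward_impl` (kick off the shared experts, run the gate and the routed experts on the main stream, then sync before combining) can be mimicked by a toy module such as the one below. This is a minimal sketch under my own assumptions: the module, its naive top-1 loop dispatch, and every name in it are illustrative and not the vLLM implementation.

import torch
from torch import nn

class ToySharedExpertMoE(nn.Module):
    def __init__(self, hidden: int, n_experts: int):
        super().__init__()
        # Created once at init, mirroring shared_experts_stream in __init__ above.
        self.shared_experts_stream = torch.cuda.Stream()
        self.gate = nn.Linear(hidden, n_experts)
        self.shared_experts = nn.Sequential(
            nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden)
        )
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        main = torch.cuda.current_stream()

        # Kick off the shared experts first so they overlap with the gate.
        self.shared_experts_stream.wait_stream(main)
        with torch.cuda.stream(self.shared_experts_stream):
            shared_out = self.shared_experts(x.clone())

        # Router/gate and a naive top-1 routed dispatch on the main stream.
        top1 = self.gate(x).argmax(dim=-1)
        routed_out = torch.stack(
            [self.experts[int(e)](t) for t, e in zip(x, top1)]
        )

        # Do not read shared_out until the side stream has finished.
        main.wait_stream(self.shared_experts_stream)
        return shared_out + routed_out

if torch.cuda.is_available():
    moe = ToySharedExpertMoE(64, 4).cuda()
    y = moe(torch.randn(8, 64, device="cuda"))
    torch.cuda.synchronize()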