@@ -1038,6 +1038,9 @@ def __init__(
10381038 expert_mapping : Optional [list [tuple [str , str , int , str ]]] = None ,
10391039 ):
10401040 super ().__init__ ()
1041+
1042+ self .se_stream = torch .cuda .Stream ()
1043+
10411044 if params_dtype is None :
10421045 params_dtype = torch .get_default_dtype ()
10431046 self .params_dtype = params_dtype
@@ -2110,7 +2113,11 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
21102113 not isinstance (self .quant_method .fused_experts , FusedMoEModularKernel )
21112114 and self .shared_experts is not None
21122115 ):
2113- shared_output = self .shared_experts (staged_hidden_states )
2116+ current_stream = torch .cuda .current_stream ()
2117+ self .se_stream .wait_stream (current_stream )
2118+ with torch .cuda .stream (self .se_stream ):
2119+ shared_output = self .shared_experts (staged_hidden_states )
2120+
21142121 else :
21152122 shared_output = None
21162123
@@ -2140,6 +2147,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
21402147 if shared_output is not None :
21412148 assert not isinstance (final_hidden_states , tuple )
21422149 assert self .shared_experts is not None
2150+
2151+ current_stream .wait_stream (self .se_stream )
2152+
21432153 final_hidden_states = (
21442154 shared_output ,
21452155 final_hidden_states ,
@@ -2234,7 +2244,10 @@ def forward_impl(
22342244 not isinstance (self .quant_method .fused_experts , FusedMoEModularKernel )
22352245 and self .shared_experts is not None
22362246 ):
2237- shared_output = self .shared_experts (hidden_states )
2247+ current_stream = torch .cuda .current_stream ()
2248+ self .se_stream .wait_stream (current_stream )
2249+ with torch .cuda .stream (self .se_stream ):
2250+ shared_output = self .shared_experts (hidden_states )
22382251 else :
22392252 shared_output = None
22402253
@@ -2278,6 +2291,9 @@ def forward_impl(
22782291 if shared_output is not None :
22792292 assert not isinstance (final_hidden_states , tuple )
22802293 assert self .shared_experts is not None
2294+
2295+ current_stream .wait_stream (self .se_stream )
2296+
22812297 final_hidden_states = (
22822298 shared_output ,
22832299 final_hidden_states ,
0 commit comments