@@ -1899,6 +1899,15 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             staged_hidden_states.copy_(hidden_states, non_blocking=True)
             staged_router_logits.copy_(router_logits, non_blocking=True)
 
+            # If there are shared experts but we are not using a modular kernel,
+            # the shared experts must be called here
+            if (not isinstance(self.quant_method.fused_experts,
+                               FusedMoEModularKernel)
+                    and self.shared_experts is not None):
+                shared_output = self.shared_experts(staged_hidden_states)
+            else:
+                shared_output = None
+
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
                 layer=self,
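The first hunk runs the shared experts eagerly on the staged copy of the activations whenever the fused-experts implementation is not a `FusedMoEModularKernel`; per the in-diff comment, the modular-kernel path is expected to handle the shared experts itself. Below is a minimal sketch of that dispatch; `FakeModularKernel` and `maybe_run_shared_experts` are hypothetical stand-ins for illustration, not vLLM APIs.

```python
# Sketch of the first hunk's dispatch. FakeModularKernel and
# maybe_run_shared_experts are hypothetical stand-ins for vLLM's
# FusedMoEModularKernel and the shared_experts module (assumptions).
class FakeModularKernel:
    """Placeholder for a kernel that computes shared experts itself."""


def maybe_run_shared_experts(fused_experts, shared_experts, staged):
    # Eagerly invoke shared experts only on the non-modular path;
    # the modular kernel is assumed to compute them internally.
    if (not isinstance(fused_experts, FakeModularKernel)
            and shared_experts is not None):
        return shared_experts(staged)
    return None
```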
@@ -1922,8 +1931,13 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            assert self.shared_experts is None or isinstance(
-                final_hidden_states, tuple)
+            if shared_output is not None:
+                assert not isinstance(final_hidden_states, tuple)
+                assert self.shared_experts is not None
+                final_hidden_states = (
+                    shared_output,
+                    final_hidden_states,
+                )
 
             if self.zero_expert_num is not None and self.zero_expert_num > 0:
                 assert isinstance(final_hidden_states, tuple)
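The second hunk replaces the old blanket assertion with explicit result assembly: on the non-modular path the quant method returns a plain tensor, which is paired with the eagerly computed shared-expert output to match the `(shared, routed)` tuple that the modular-kernel path already returns (as the deleted assertion implies). A self-contained sketch under those assumptions follows; the helper name and toy tensors are hypothetical.

```python
import torch


def assemble_result(shared_output, final_hidden_states):
    # Pair the routed output with the shared-expert output when the
    # shared experts were run eagerly; otherwise pass through unchanged.
    if shared_output is not None:
        assert not isinstance(final_hidden_states, tuple)
        return (shared_output, final_hidden_states)
    return final_hidden_states


# Usage: with a shared output the result becomes a (shared, routed)
# tuple; without one, the routed tensor passes through untouched.
routed = torch.randn(4, 8)
shared = torch.randn(4, 8)
assert isinstance(assemble_result(shared, routed), tuple)
assert torch.equal(assemble_result(None, routed), routed)
```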