@@ -154,47 +154,6 @@ def __init__(
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
         self.config = config
 
-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        if attn_metadata is None:
-            attn_metadata = forward_context.attn_metadata
-
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = forward_context.in_profile_run
-
-        is_prefill = forward_context.with_prefill
-        # If this node is kv_consumer, we force the moe always runs in decode path to make sure
-        # the behaviour aligned between dummy_run and normal model_execute.
-        if self.kv_consumer:
-            is_prefill = False
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        experts_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekDBOMoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=self.shared_experts)
-
-        shared_experts_hidden = experts_hidden_states[1]
-        if not (self.shared_experts.down_proj.reduce_results
-                and self.shared_experts.down_proj.tp_size > 1):
-            shared_experts_hidden = tensor_model_parallel_all_reduce(
-                shared_experts_hidden)
-
-        hidden_states = (
-            experts_hidden_states[0] * self.routed_scaling_factor +
-            shared_experts_hidden)
-
-        return hidden_states
-
     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(
             self,
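
For reference, the removed forward() follows a DeepSeek-style MoE pattern: the gate produces router logits, the fused experts call returns the routed and shared-expert hidden states, the shared-expert output is all-reduced across tensor-parallel ranks unless down_proj already reduced it, and the two paths are summed with the routed path scaled by routed_scaling_factor. The sketch below is a minimal single-device illustration of that combine step only; routed_scaling_factor and the router_logits shape come from the diff, while the gate, the expert stand-ins, and all shapes are assumptions, and the all-reduce is a no-op here.

# Minimal single-device sketch of the combine performed by the removed forward().
# The gate, the "experts", and the shapes are illustrative stand-ins; the
# tensor-parallel all-reduce is omitted because TP size is 1 in this sketch.
import torch


def sketch_moe_combine(hidden_states: torch.Tensor,
                       gate: torch.nn.Linear,
                       top_k: int,
                       routed_scaling_factor: float) -> torch.Tensor:
    # router_logits: (num_tokens, n_experts)
    router_logits = gate(hidden_states)
    topk_weights, _ = torch.softmax(router_logits, dim=-1).topk(top_k, dim=-1)

    # Stand-ins for the fused experts call, which returns the routed output and
    # the shared-expert output for the same tokens.
    routed_hidden = hidden_states * topk_weights.sum(dim=-1, keepdim=True)
    shared_hidden = hidden_states

    # With tensor parallelism, shared_hidden would be all-reduced here unless the
    # shared experts' down_proj already reduced it; on a single device it is a no-op.

    # Scale the routed path and add the shared path, as in the removed code.
    return routed_hidden * routed_scaling_factor + shared_hidden


if __name__ == "__main__":
    torch.manual_seed(0)
    gate = torch.nn.Linear(8, 16, bias=False)   # hidden size 8, 16 "experts"
    x = torch.randn(4, 8)                       # 4 tokens
    out = sketch_moe_combine(x, gate, top_k=2, routed_scaling_factor=2.5)
    print(out.shape)                            # torch.Size([4, 8])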