diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py
index 3f7ab5761..9145c66a6 100644
--- a/src/llmcompressor/modeling/llama4.py
+++ b/src/llmcompressor/modeling/llama4.py
@@ -48,26 +48,36 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_scores, router_logits = self.router(hidden_states)  # transformers>=4.54
-
+        router_scores, router_logits = self.router(hidden_states)
         out = self.shared_expert(hidden_states)
 
-        for expert_index in range(self.num_experts):
-            # find expert scores
-            expert_score = router_scores[:, expert_index].unsqueeze(-1)
-            top_token_mask = expert_score[:, 0] > 0
+        _, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+        expert_mask = torch.nn.functional.one_hot(
+            router_indices, num_classes=self.num_experts
+        ).permute(2, 1, 0)  # (num_experts, top_k, batch_size * sequence_length)
+
+        for i in range(self.num_experts):
+            # fetch relevant token indices for this expert
+            token_idx = torch.where(expert_mask[i].squeeze(0))
 
-            # llama4 applies scores before expert relu
-            expert_in = hidden_states * expert_score
+            # Original Llama4 definition applies the score to the hidden states
+            # before the expert, which results in NaNs during calibration:
+            # routed_in = hidden_states * router_scores[:, i].reshape(-1, 1)
 
-            # calibrate experts
             if self.calibrate_all_experts:
-                expert_out = self.experts[expert_index](expert_in)[top_token_mask]
+                # all tokens for this expert
+                expert_out = self.experts[i](hidden_states)[token_idx]
             else:
-                expert_out = self.experts[expert_index](expert_in[top_token_mask])
-
-            # accumulate output
-            out[top_token_mask] += expert_out
+                # only relevant tokens for this expert
+                expert_out = self.experts[i](hidden_states[token_idx])
+
+            if len(token_idx) > 0:
+                # Deviation from the original Llama4 definition to avoid
+                # NaNs during calibration
+                weighted_output = expert_out * router_scores[:, i][token_idx].reshape(
+                    -1, 1
+                )
+                out[token_idx] += weighted_output
 
         return out, router_logits
 
diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py
index 78fb4ee6d..90361bd80 100644
--- a/tests/llmcompressor/modeling/test_calib_llama4.py
+++ b/tests/llmcompressor/modeling/test_calib_llama4.py
@@ -85,11 +85,11 @@ def test_calib_llama4_module():
     module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=True)
     with calibration_forward_context(module):
         out, router_logits = module(sample)
-    assert torch.nn.functional.mse_loss(true_out, out) < 1e-10
-    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10
+    assert torch.nn.functional.mse_loss(true_out, out) < 0.1
+    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 0.1
 
     module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=False)
     with calibration_forward_context(module):
         out, router_logits = module(sample)
-    assert torch.nn.functional.mse_loss(true_out, out) < 1e-10
-    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10
+    assert torch.nn.functional.mse_loss(true_out, out) < 0.1
+    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 0.1
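
Note: the following is a minimal standalone sketch of the routing path the patch implements (top-k indices from the router logits, a one-hot expert mask, and router scores applied to the expert outputs rather than the inputs). The route_tokens helper, the toy linear experts, the sigmoid score stand-in, and the dimensions are illustrative assumptions only and are not part of the change.

import torch


def route_tokens(hidden_states, router_scores, router_logits, experts, top_k):
    """Illustrative re-implementation of the routing scheme in the patch."""
    num_experts = len(experts)
    out = torch.zeros_like(hidden_states)

    # Pick the top_k experts per token from the raw router logits.
    _, router_indices = torch.topk(router_logits, top_k, dim=1)

    # One-hot over experts, permuted to (num_experts, top_k, num_tokens),
    # mirroring the expert_mask construction in the patch.
    expert_mask = torch.nn.functional.one_hot(
        router_indices, num_classes=num_experts
    ).permute(2, 1, 0)

    for i in range(num_experts):
        # Tokens routed to expert i (in any of its top_k slots).
        token_idx = torch.where(expert_mask[i].sum(dim=0) > 0)[0]
        if token_idx.numel() == 0:
            continue

        # Run the expert on the raw hidden states and scale the *output*
        # by the router score -- the deviation from the original Llama4
        # definition (which scales the input) that the patch uses to avoid
        # NaNs during calibration.
        expert_out = experts[i](hidden_states[token_idx])
        out[token_idx] += expert_out * router_scores[token_idx, i].unsqueeze(-1)

    return out


# Toy usage: 8 tokens, hidden size 16, 4 linear "experts", top_k = 1.
torch.manual_seed(0)
hidden = torch.randn(8, 16)
logits = torch.randn(8, 4)
scores = torch.sigmoid(logits)  # stand-in for the router's score activation
experts = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(4))
print(route_tokens(hidden, scores, logits, experts, top_k=1).shape)  # torch.Size([8, 16])

In the patch itself, calibrate_all_experts=True additionally runs every expert over all tokens so each expert sees calibration data, while still accumulating only the routed subset into the output via the token_idx slice.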