38 changes: 24 additions & 14 deletions src/llmcompressor/modeling/llama4.py
@@ -48,26 +48,36 @@ def __init__(

 def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-    router_scores, router_logits = self.router(hidden_states)  # transformers>=4.54
-
+    router_scores, router_logits = self.router(hidden_states)
     out = self.shared_expert(hidden_states)

-    for expert_index in range(self.num_experts):
-        # find expert scores
-        expert_score = router_scores[:, expert_index].unsqueeze(-1)
-        top_token_mask = expert_score[:, 0] > 0
+    _, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+    expert_mask = torch.nn.functional.one_hot(
+        router_indices, num_classes=self.num_experts
+    ).permute(2, 1, 0)  # (num_experts, top_k, batch_size * sequence_length)
+
+    for i in range(self.num_experts):
+        # fetch relevant token indices for this expert
+        token_idx = torch.where(expert_mask[i].squeeze(0))

-        # llama4 applies scores before expert relu
-        expert_in = hidden_states * expert_score
+        # Original Llama4 definition applies the score to the hidden states
+        # before the expert; doing so results in NaNs during calibration
+        # routed_in = hidden_states * router_scores[:, i].reshape(-1, 1)

         # calibrate experts
         if self.calibrate_all_experts:
-            expert_out = self.experts[expert_index](expert_in)[top_token_mask]
+            # all tokens for this expert
+            expert_out = self.experts[i](hidden_states)[token_idx]
         else:
-            expert_out = self.experts[expert_index](expert_in[top_token_mask])
-
-        # accumulate output
-        out[top_token_mask] += expert_out
+            # only relevant tokens for this expert
+            expert_out = self.experts[i](hidden_states[token_idx])
+
+        if len(token_idx) > 0:
+            # Deviation from original Llama4 definition to avoid
+            # NaNs during calibration
+            weighted_output = expert_out * router_scores[:, i][token_idx].reshape(
+                -1, 1
+            )
+            out[token_idx] += weighted_output

     return out, router_logits

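For readers unfamiliar with the masking idiom used above, here is a minimal, self-contained sketch of how torch.topk, one_hot, and torch.where combine to recover the token indices routed to each expert. Shapes and values are made up for illustration (top_k = 1 mirrors Llama4's single routed expert per token); this is not the module's actual code.

import torch

# illustrative sketch; toy shapes, random logits
num_experts, top_k, num_tokens = 4, 1, 6
router_logits = torch.randn(num_tokens, num_experts)

# pick the routed expert per token, then build a per-expert membership mask
_, router_indices = torch.topk(router_logits, top_k, dim=1)  # (num_tokens, top_k)
expert_mask = torch.nn.functional.one_hot(
    router_indices, num_classes=num_experts
).permute(2, 1, 0)  # (num_experts, top_k, num_tokens)

for i in range(num_experts):
    # tuple of index tensors selecting the tokens routed to expert i
    token_idx = torch.where(expert_mask[i].squeeze(0))
    print(f"expert {i} receives tokens {token_idx[0].tolist()}")

Indexing hidden_states[token_idx] with that tuple is what lets the calibrate_all_experts=False path run each expert only on its routed tokens.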
8 changes: 4 additions & 4 deletions tests/llmcompressor/modeling/test_calib_llama4.py
@@ -85,11 +85,11 @@ def test_calib_llama4_module():
     module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=True)
     with calibration_forward_context(module):
         out, router_logits = module(sample)
-    assert torch.nn.functional.mse_loss(true_out, out) < 1e-10
-    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10
+    assert torch.nn.functional.mse_loss(true_out, out) < 0.1
+    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 0.1

     module = SequentialLlama4TextMoe(original, config, calibrate_all_experts=False)
     with calibration_forward_context(module):
         out, router_logits = module(sample)
-    assert torch.nn.functional.mse_loss(true_out, out) < 1e-10
-    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 1e-10
+    assert torch.nn.functional.mse_loss(true_out, out) < 0.1
+    assert torch.nn.functional.mse_loss(true_router_logits, router_logits) < 0.1
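The looser tolerance (0.1 instead of 1e-10) appears consistent with the change above: because the experts are non-linear, scaling the expert output by the router score is not numerically identical to scaling the expert input, so the sequential module can only approximate the original one. A toy illustration of that gap, using a hypothetical two-layer expert rather than the real Llama4 gated MLP:

import torch

torch.manual_seed(0)
# hypothetical stand-in expert, not the actual Llama4 expert module
expert = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.SiLU(), torch.nn.Linear(16, 8)
)
x = torch.randn(4, 8)
score = torch.sigmoid(torch.randn(4, 1))

out_scale_input = expert(x * score)   # original Llama4 order: scale input, then expert
out_scale_output = expert(x) * score  # calibration-friendly order: expert, then scale output

# small but non-zero gap, hence an MSE tolerance rather than exact equality
print(torch.nn.functional.mse_loss(out_scale_input, out_scale_output).item())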