Skip to content

Commit

Permalink
[Bugfix] Fix compute_logits in Jamba (#6093)
Browse files Browse the repository at this point in the history
  • Loading branch information
ywang96 authored Jul 3, 2024
1 parent f1c7813 commit 7cd2ebb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def _prepare_mamba_cache(self):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

Expand Down

0 comments on commit 7cd2ebb

Please sign in to comment.