From 20958800220b1ca8b689bb61b5512f82260d781e Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 3 Jul 2024 00:32:35 -0700 Subject: [PATCH] [Bugfix] Fix `compute_logits` in Jamba (#6093) Signed-off-by: Alvant --- vllm/model_executor/models/jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index c485d3779d9a6..bf330c7770d12 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -876,7 +876,7 @@ def _prepare_mamba_cache(self): def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits