Skip to content

Commit 2833dcf

Browse files
pooyadavoodi authored and weilong.yu committed
[Bugfix] Use runner_type instead of task in GritLM (vllm-project#11144)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
1 parent ff20e57 commit 2833dcf

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

tests/models/embedding/language/test_gritlm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_find_array(monkeypatch):
3535
from vllm.model_executor.models.gritlm import GritLMPooler
3636

3737
# Create an LLM object to get the model config.
38-
llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
38+
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
3939
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
4040

4141
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
@@ -55,7 +55,7 @@ def server_embedding():
5555
with pytest.MonkeyPatch.context() as mp:
5656
mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
5757

58-
args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
58+
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
5959
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
6060
yield remote_server
6161

@@ -141,7 +141,7 @@ def test_gritlm_offline_embedding(monkeypatch):
141141

142142
queries, q_instruction, documents, d_instruction = get_test_data()
143143

144-
llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
144+
llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN)
145145

146146
d_rep = run_llm_encode(
147147
llm,

vllm/model_executor/models/gritlm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,12 @@ def __init__(
203203
) -> None:
204204
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
205205

206-
self.task = vllm_config.model_config.task
206+
self.runner_type = vllm_config.model_config.runner_type
207207

208208
self._pooler = GritLMPooler(vllm_config.model_config)
209209

210210
for layer in self.model.layers:
211-
if self.task == "embedding" and hasattr(layer, "self_attn"):
211+
if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
212212
assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
213213
"GritLM embedding is only supported by XFormers backend, "
214214
"which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
@@ -222,8 +222,8 @@ def forward(
222222
**kwargs,
223223
) -> Union[torch.Tensor, IntermediateTensors]:
224224

225-
# Change attention to non-causal for embedding task.
226-
if self.task == "embedding":
225+
# Change attention to non-causal for pooling tasks.
226+
if self.runner_type == "pooling":
227227
assert attn_metadata.prefill_metadata.attn_bias is None
228228
attn_metadata.prefill_metadata.attn_bias = [
229229
BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)

0 commit comments

Comments (0)