Skip to content

Commit 112779f

Browse files
authored
Merge pull request vllm-project#7 from prashanth058/mlm-connector-support
add connector support
2 parents 5c156c9 + a69bde7 commit 112779f

File tree

11 files changed

+315
-123
lines changed

11 files changed

+315
-123
lines changed

tests/lora/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,21 @@ def qwen25vl_lora_files():
225225
return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
226226

227227

228+
@pytest.fixture(scope="session")
229+
def qwen2vl_language_lora_files():
230+
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-language")
231+
232+
233+
@pytest.fixture(scope="session")
234+
def qwen2vl_vision_tower_connector_lora_files():
235+
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")
236+
237+
238+
@pytest.fixture(scope="session")
239+
def qwen2vl_vision_tower_lora_files():
240+
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower")
241+
242+
228243
@pytest.fixture(scope="session")
229244
def tinyllama_lora_files():
230245
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

tests/lora/test_qwen2vl.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def run_test(
7979
lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
8080
outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
8181
generated_texts = [output.outputs[0].text.strip() for output in outputs]
82-
8382
# Validate outputs
8483
for generated, expected in zip(generated_texts, expected_outputs):
8584
assert expected.startswith(generated), (
@@ -130,6 +129,22 @@ def run_beam_search_test(
130129
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
131130
]
132131

132+
EXPECTED_OUTPUTS_LANGUAGE = [
133+
"A stop sign is shown in an Asian city, with buildings and a car in the "
134+
"background.",
135+
"The Tokyo Skytree can be seen behind the pink blossoms of the cherry trees.",
136+
]
137+
138+
EXPECTED_OUTPUTS_VISION = [
139+
"A stop sign in front of oriental buildings.",
140+
"A tree with pink flowers in front of it and a blue sky behind the flowers.",
141+
]
142+
143+
EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
144+
"A stop sign is located on the street of a Chinese neighborhood.",
145+
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
146+
]
147+
133148
# NOTE - beam search .text contains the whole text
134149
EXPECTED_BEAM_SEARCH_OUTPUTS = [
135150
[
@@ -190,3 +205,64 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
190205
# Test with different LoRA IDs
191206
for lora_id in [1, 2]:
192207
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
208+
209+
210+
@pytest.mark.xfail(
211+
current_platform.is_rocm(),
212+
reason="Qwen2-VL dependency xformers incompatible with ROCm",
213+
)
214+
def test_qwen2vl_language_lora(qwen2vl_language_lora_files):
215+
"""
216+
Test language-only LoRA adapter.
217+
"""
218+
config = TestConfig(
219+
model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_language_lora_files
220+
)
221+
tester = Qwen2VLTester(config)
222+
for lora_id in [1, 2]:
223+
tester.run_test(
224+
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_LANGUAGE, lora_id=lora_id
225+
)
226+
227+
228+
@pytest.mark.xfail(
229+
current_platform.is_rocm(),
230+
reason="Qwen2-VL dependency xformers incompatible with ROCm",
231+
)
232+
def test_qwen2vl_vision_lora(qwen2vl_vision_tower_connector_lora_files):
233+
"""
234+
Test vision tower + connector LoRA adapter.
235+
"""
236+
config = TestConfig(
237+
model_path=QWEN2VL_MODEL_PATH,
238+
lora_path=qwen2vl_vision_tower_connector_lora_files,
239+
)
240+
tester = Qwen2VLTester(config)
241+
for lora_id in [1, 2]:
242+
tester.run_test(
243+
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_VISION, lora_id=lora_id
244+
)
245+
246+
247+
@pytest.mark.xfail(
248+
current_platform.is_rocm(),
249+
reason="Qwen2-VL dependency xformers incompatible with ROCm",
250+
)
251+
def test_qwen2vl_vision_no_connector_lora(
252+
qwen2vl_vision_tower_lora_files,
253+
):
254+
"""
255+
Test vision tower only LoRA adapter.
256+
257+
"""
258+
config = TestConfig(
259+
model_path=QWEN2VL_MODEL_PATH,
260+
lora_path=qwen2vl_vision_tower_lora_files,
261+
)
262+
tester = Qwen2VLTester(config)
263+
for lora_id in [1, 2]:
264+
tester.run_test(
265+
TEST_IMAGES,
266+
expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
267+
lora_id=lora_id,
268+
)

vllm/lora/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
RowParallelLinearWithLoRA,
1818
RowParallelLinearWithShardedLoRA,
1919
)
20-
from vllm.lora.layers.utils import LoRAMapping
20+
from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
2121
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
2222

2323
__all__ = [
@@ -36,4 +36,5 @@
3636
"RowParallelLinearWithShardedLoRA",
3737
"ReplicatedLinearWithLoRA",
3838
"LoRAMapping",
39+
"LoRAMappingType",
3940
]

vllm/lora/layers/row_parallel_linear.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,25 @@ def forward(
6363
input_parallel = splitted_input[self.tp_rank].contiguous()
6464

6565
# Matrix multiply.
66-
output_parallel = self.apply(input_parallel)
66+
# Only fuse bias add into GEMM for rank 0 (matches base
67+
# RowParallelLinear behavior). This ensures bias will not get
68+
# added more than once in TP>1 case and matches the numerical
69+
# behavior of the unwrapped layer
70+
bias_ = (
71+
None
72+
if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
73+
else self.base_layer.bias
74+
)
75+
output_parallel = self.apply(input_parallel, bias_)
76+
6777
if self.base_layer.reduce_results and self.tp_size > 1:
6878
output_ = tensor_model_parallel_all_reduce(output_parallel)
6979
else:
7080
output_ = output_parallel
7181

72-
if not self.base_layer.skip_bias_add:
73-
output = (
74-
output_ + self.base_layer.bias
75-
if self.base_layer.bias is not None
76-
else output_
77-
)
78-
output_bias = None
79-
else:
80-
output = output_
81-
output_bias = self.base_layer.bias
82+
# Bias was already added by rank 0 in apply(), no need to add again
83+
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
84+
output = output_
8285

8386
if not self.base_layer.return_bias:
8487
return output

vllm/lora/layers/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,24 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from dataclasses import dataclass
5+
from enum import Enum
56

67
import torch
78
import torch.nn as nn
89

910

11+
class LoRAMappingType(Enum):
12+
LANGUAGE = 1
13+
TOWER = 2
14+
CONNECTOR = 3
15+
16+
1017
@dataclass
1118
class LoRAMapping:
1219
index_mapping: tuple[int, ...]
1320
prompt_mapping: tuple[int, ...]
1421
is_prefill: bool = False
15-
is_mm_input: bool = False
22+
type: LoRAMappingType = LoRAMappingType.LANGUAGE
1623

1724
def __post_init__(self):
1825
self.index_mapping = tuple(self.index_mapping)

0 commit comments

Comments (0)