
Commit 44bfb06

jeejeelee authored and shreyankg committed
[Model] Add LoRA support for TransformersModel (vllm-project#13770)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 5c3b4a3 commit 44bfb06

File tree: 7 files changed, +166 −70 lines changed


.buildkite/test-pipeline.yaml (2 additions, 1 deletion)

```diff
@@ -275,7 +275,7 @@ steps:
   source_file_dependencies:
   - vllm/lora
   - tests/lora
-  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
+  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
   parallelism: 4
 
 - label: PyTorch Fullgraph Smoke Test # 9min
@@ -589,6 +589,7 @@ steps:
   - pytest -v -s -x lora/test_chatglm3_tp.py
   - pytest -v -s -x lora/test_llama_tp.py
   - pytest -v -s -x lora/test_minicpmv_tp.py
+  - pytest -v -s -x lora/test_transfomers_model.py
 
 
 - label: Weight Loading Multiple GPU Test # 33min
```

docs/source/models/supported_models.md (1 addition, 14 deletions)

````diff
@@ -62,20 +62,7 @@ Transformers fallback has supported most of available quantization in vLLM (exce
 
 ##### LoRA
 
-LoRA hasn't supported on transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
-
-Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
-
-Hints as to how this would look like:
-
-```python
-class TransformersModel(nn.Module, SupportsLoRA):
-    def __init__(*):
-        ...
-        self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
-```
-
-Blocker is that you need to specify supported lora layers, when we would ideally want to load whatever is inside the checkpoint!
+Transformers fallback has supported LoRA. The usage way is identical to how LoRA works with models supported by vLLM. If you encounter any issues, please open an issue.
 
 ##### Remote code
 
````
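For context, a minimal end-to-end sketch of what the new doc line means in practice. It assumes the model ("ArthurZ/ilama-3.2-1B") and adapter ("jeeejeee/ilama-text2sql-spider") used by the test added in this commit; any other Transformers-fallback model plus LoRA adapter pair should follow the same pattern.

```python
# Minimal usage sketch: LoRA with a Transformers-fallback model looks the same
# as with any natively supported model -- enable LoRA on the LLM and pass a
# LoRARequest per call. Model/adapter names are taken from the new test.
from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")

llm = vllm.LLM(
    "ArthurZ/ilama-3.2-1B",  # served through the Transformers fallback
    enable_lora=True,
    max_lora_rank=16,
    max_model_len=1024,
    trust_remote_code=True,
)

outputs = llm.generate(
    ["How many singers do we have?"],
    vllm.SamplingParams(temperature=0, max_tokens=32),
    # LoRARequest(name, id, path) attaches the adapter to this request.
    lora_request=LoRARequest("sql-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)
```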

tests/lora/conftest.py (5 additions, 0 deletions)

```diff
@@ -240,6 +240,11 @@ def baichuan_regex_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
 
 
+@pytest.fixture(scope="session")
+def ilama_lora_files():
+    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
+
+
 @pytest.fixture(scope="session")
 def minicpmv_lora_files():
     return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
```
tests/lora/test_transfomers_model.py (new file, 120 additions, 0 deletions)

```python
# SPDX-License-Identifier: Apache-2.0

from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "ArthurZ/ilama-3.2-1B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501

EXPECTED_LORA_OUTPUT = [
    "SELECT count(*) FROM singer",
    "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",  # noqa: E501
    "SELECT DISTINCT Country FROM singer WHERE Age > 20",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "What are all distinct countries where singers above age 20 are from?"  # noqa: E501
        ),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_ilama_lora(ilama_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=16,
                   tensor_parallel_size=1,
                   trust_remote_code=True,
                   enable_chunked_prefill=True)

    output1 = do_sample(llm, ilama_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]
    output2 = do_sample(llm, ilama_lora_files, lora_id=2)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_ilama_lora_tp4(ilama_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=16,
                   tensor_parallel_size=4,
                   trust_remote_code=True,
                   fully_sharded_loras=False,
                   enable_chunked_prefill=True)

    output1 = do_sample(llm, ilama_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]
    output2 = do_sample(llm, ilama_lora_files, lora_id=2)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=16,
                   tensor_parallel_size=4,
                   trust_remote_code=True,
                   fully_sharded_loras=True,
                   enable_chunked_prefill=True)
    output1 = do_sample(llm, ilama_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]
    output2 = do_sample(llm, ilama_lora_files, lora_id=2)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output2[i] == EXPECTED_LORA_OUTPUT[i]
```

vllm/lora/layers.py (18 additions, 7 deletions)

```diff
@@ -401,6 +401,11 @@ def apply(self,
                                             self.output_slices)
         return output
 
+    @classmethod
+    def get_source_layer(cls, source_layer: nn.Module) -> type:
+        # Check parent_cls in case source_layer is a HFCompatibleLinear.
+        return getattr(source_layer, "parent_cls", type(source_layer))
+
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
 
@@ -443,7 +448,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return type(source_layer) is ReplicatedLinear
+        source_layer = cls.get_source_layer(source_layer)
+        return source_layer is ReplicatedLinear
 
 
 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
@@ -539,8 +545,9 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return type(source_layer) is ColumnParallelLinear or (
-            type(source_layer) is MergedColumnParallelLinear
+        source_layer = cls.get_source_layer(source_layer)
+        return source_layer is ColumnParallelLinear or (
+            source_layer is MergedColumnParallelLinear
             and len(packed_modules_list) == 1)
 
 
@@ -682,7 +689,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return (type(source_layer) is MergedColumnParallelLinear
+        source_layer = cls.get_source_layer(source_layer)
+        return (source_layer is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)
 
 
@@ -750,7 +758,8 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
-        return type(source_layer) is QKVParallelLinear and len(
+        source_layer = cls.get_source_layer(source_layer)
+        return source_layer is QKVParallelLinear and len(
             packed_modules_list) == 1
 
 
@@ -811,7 +820,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return (type(source_layer) is QKVParallelLinear
+        source_layer = cls.get_source_layer(source_layer)
+        return (source_layer is QKVParallelLinear
                 and len(packed_modules_list) == 3)
 
 
@@ -896,7 +906,8 @@ def can_replace_layer(
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        return type(source_layer) is RowParallelLinear
+        source_layer = cls.get_source_layer(source_layer)
+        return source_layer is RowParallelLinear
 
 
 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
```
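The common thread in these hunks: every `can_replace_layer` check now goes through the new `get_source_layer` classmethod, so a Transformers-backend `HFCompatibleLinear` wrapper (which advertises the wrapped vLLM class via `parent_cls`) matches the same LoRA class as the plain vLLM layer it wraps. A self-contained toy sketch of that dispatch; the `Fake*` names are invented for illustration, only the `getattr` pattern mirrors the real code above.

```python
# Toy illustration of the get_source_layer dispatch; the Fake* classes are
# stand-ins made up for this example.
import torch.nn as nn


class FakeReplicatedLinear(nn.Linear):
    """Stand-in for vLLM's ReplicatedLinear."""


class FakeHFCompatibleLinear(FakeReplicatedLinear):
    """Stand-in for the Transformers-backend wrapper with a parent_cls hook."""

    @property
    def parent_cls(self) -> type:
        return FakeReplicatedLinear


def get_source_layer(source_layer: nn.Module) -> type:
    # Prefer the wrapper's advertised parent class, else the concrete type.
    return getattr(source_layer, "parent_cls", type(source_layer))


# Both the plain layer and the wrapped layer resolve to the same class,
# so the same ...WithLoRA implementation can replace either of them.
assert get_source_layer(FakeReplicatedLinear(8, 8)) is FakeReplicatedLinear
assert get_source_layer(FakeHFCompatibleLinear(8, 8)) is FakeReplicatedLinear
```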

vllm/lora/utils.py (14 additions, 11 deletions)

```diff
@@ -66,17 +66,20 @@ def from_layer(layer: nn.Module,
                                       lora_config=lora_config,
                                       packed_modules_list=packed_modules_list,
                                       model_config=model_config):
-            ret = lora_cls(layer)
-            ret.create_lora_weights(max_loras, lora_config, model_config)
-            return ret
-
-    # The Case for HFCompatibleLinear
-    if (hasattr(layer, "get_lora_class")
-            and layer.__class__.__name__ == "HFCompatibleLinear"):
-        lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
-        ret = lora_cls(layer)
-        ret.create_lora_weights(max_loras, lora_config, model_config)
-        return ret
+            instance_layer = lora_cls(layer)
+            if layer.__class__.__name__ == "HFCompatibleLinear":
+                # HACK: Make the forward method compatible with the original
+                # forward method of the instance_layer.
+                original_forward = instance_layer.forward
+
+                def new_forward(input):
+                    input = input.squeeze(0)
+                    return original_forward(input)[0]  # noqa: B023
+
+                instance_layer.forward = new_forward
+            instance_layer.create_lora_weights(max_loras, lora_config,
+                                               model_config)
+            return instance_layer
     return layer
 
 
```
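The HACK exists because the LoRA linear layers follow vLLM's calling convention (a squeezed input and an `(output, output_bias)` return), while the module graph produced by the Transformers backend calls its linears like plain `nn.Linear`s and expects a single tensor back. Wrapping `forward` in a closure bridges the two. Below is a toy sketch of that adapter pattern, with a made-up tuple-returning layer standing in for the real LoRA classes.

```python
# Toy sketch of the forward-adapter closure; TupleReturningLinear is invented
# here to mimic a vLLM-style layer that returns (output, output_bias).
import torch
import torch.nn as nn


class TupleReturningLinear(nn.Linear):

    def forward(self, x: torch.Tensor):
        return super().forward(x), None  # (output, output_bias)


instance_layer = TupleReturningLinear(8, 8)
original_forward = instance_layer.forward


def new_forward(input: torch.Tensor) -> torch.Tensor:
    input = input.squeeze(0)           # drop the leading dim added by the caller
    return original_forward(input)[0]  # keep only the output tensor


# Shadow the bound method on the instance, exactly like the patched code above.
instance_layer.forward = new_forward

out = instance_layer.forward(torch.randn(1, 4, 8))
print(out.shape)  # torch.Size([4, 8])
```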

vllm/model_executor/models/transformers.py (6 additions, 37 deletions)

```diff
@@ -27,11 +27,6 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.logger import init_logger
-from vllm.lora.fully_sharded_layers import (
-    ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA)
-from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
-                              ReplicatedLinearWithLoRA,
-                              RowParallelLinearWithLoRA)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -43,7 +38,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsQuant
+from .interfaces import SupportsLoRA, SupportsQuant
 from .utils import maybe_prefix
 
 logger = init_logger(__name__)
@@ -102,44 +97,18 @@ def replace_linear_class(
         "rowwise": RowParallelLinear,
     }.get(style, ReplicatedLinear)
 
-    lora_linear_cls = {
-        ColumnParallelLinear: {
-            True: ColumnParallelLinearWithShardedLoRA,  # fully sharded
-            False: ColumnParallelLinearWithLoRA  # not fully sharded
-        },
-        RowParallelLinear: {
-            True: RowParallelLinearWithShardedLoRA,
-            False: RowParallelLinearWithLoRA
-        },
-        # ReplicatedLinear doesn't support fully sharded LoRA yet,
-        # so we use the same class for both cases.
-        ReplicatedLinear: {
-            True: ReplicatedLinearWithLoRA,
-            False: ReplicatedLinearWithLoRA
-        }
-    }
-
     class HFCompatibleLinear(vllm_linear_cls):
         """
         Wrapper class that removes `output_bias` from returned output.
         """
+        # NOTE: The LoRA layer needs to use `parent_cls`.
+        @property
+        def parent_cls(self) -> type:
+            return vllm_linear_cls
 
         def forward(self, input: torch.Tensor) -> torch.Tensor:
            return super().forward(input)[0]
 
-        @classmethod
-        def get_lora_class(cls, fully_sharded: bool = False):
-            """
-            Get the LoRA class corresponding to the current transformer
-            linear class.
-
-            Args:
-                fully_sharded (bool): If True, select the LoRA class variant
-                    that supports fully sharded LoRA. Defaults to False.
-
-            """
-            return lora_linear_cls[vllm_linear_cls][fully_sharded]
-
     return HFCompatibleLinear(
         input_size=linear.in_features,
         output_size=linear.out_features,
@@ -148,7 +117,7 @@ def get_lora_class(cls, fully_sharded: bool = False):
     )
 
 
-class TransformersModel(nn.Module, SupportsQuant):
+class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
                          ]  # TODO transformers will have a util to get it
```