
Commit e4f5b9c

jeejeelee authored and Akshat-Tripathi committed

[Misc] Reduce LoRA-related static variable (vllm-project#13166)
1 parent 35380d4 commit e4f5b9c


41 files changed: +120 -395 lines
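The "static variable" in the title is the hand-maintained supported_lora_modules list that each LoRA-capable model class used to declare; after this change the LoRA manager derives that list from the model's structure via get_supported_lora_modules (added in vllm/lora/utils.py below). A minimal sketch of the removed pattern, using a hypothetical model class rather than a real vLLM one:

    import torch.nn as nn

    from vllm.model_executor.models.interfaces import SupportsLoRA


    class SomeModelForCausalLM(nn.Module, SupportsLoRA):
        # Still declared per model: how packed layers map onto the
        # constituent weights in the adapter checkpoint.
        packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}

        # Removed by this commit: a duplicated, hand-maintained list such as
        # supported_lora_modules = ["qkv_proj", "o_proj", "gate_up_proj", ...]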

tests/lora/conftest.py (15 additions, 2 deletions)

@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.platforms import current_platform


@@ -98,9 +99,13 @@ def dist_init_torch_only():
         backend=backend)


+class DummyLoRAModel(nn.Sequential, SupportsLoRA):
+    pass
+
+
 @pytest.fixture
 def dummy_model() -> nn.Module:
-    model = nn.Sequential(
+    model = DummyLoRAModel(
         OrderedDict([
             ("dense1", ColumnParallelLinear(764, 100)),
             ("dense2", RowParallelLinear(100, 50)),
@@ -121,12 +126,13 @@ def dummy_model() -> nn.Module:
             ("sampler", Sampler())
         ]))
     model.config = MagicMock()
+    model.embedding_modules = {"lm_head": "lm_head"}
     return model


 @pytest.fixture
 def dummy_model_gate_up() -> nn.Module:
-    model = nn.Sequential(
+    model = DummyLoRAModel(
         OrderedDict([
             ("dense1", ColumnParallelLinear(764, 100)),
             ("dense2", RowParallelLinear(100, 50)),
@@ -147,6 +153,13 @@ def dummy_model_gate_up() -> nn.Module:
             ("sampler", Sampler())
         ]))
     model.config = MagicMock()
+    model.packed_modules_mapping = {
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    model.embedding_modules = {"lm_head": "lm_head"}
     return model


tests/lora/test_lora_checkpoints.py (9 additions, 4 deletions)

@@ -12,6 +12,12 @@
 lora_lst = [
     "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
 ]
+BAICHUAN_LORA_MODULES = [
+    "W_pack",
+    "o_proj",
+    "gate_up_proj",
+    "down_proj",
+]


 @pytest.mark.parametrize("lora_name", lora_lst)
@@ -22,12 +28,11 @@ def test_load_checkpoints(
     baichuan_regex_lora_files,
     chatglm3_lora_files,
 ):
-    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
     embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
     expected_lora_modules: List[str] = []
-    for module in supported_lora_modules:
+    for module in BAICHUAN_LORA_MODULES:
         if module in packed_modules_mapping:
             expected_lora_modules.extend(packed_modules_mapping[module])
         else:
@@ -90,12 +95,12 @@ def test_load_checkpoints(


 def test_lora_weights_mapping(baichuan_lora_files):
-    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
     embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
     expected_lora_modules: List[str] = []
-    for module in supported_lora_modules:
+    for module in BAICHUAN_LORA_MODULES:
         if module in packed_modules_mapping:
             expected_lora_modules.extend(packed_modules_mapping[module])
         else:
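The loop above expands each top-level LoRA module name into the target-module names expected in the adapter checkpoint: packed modules are replaced by their constituents from packed_modules_mapping, everything else passes through unchanged. A small worked example with assumed inputs (the mapping below is illustrative, loosely mirroring the Baichuan constants above):

    from typing import Dict, List

    # Assumed inputs for illustration only.
    lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]
    packed_modules_mapping: Dict[str, List[str]] = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    expected_lora_modules: List[str] = []
    for module in lora_modules:
        if module in packed_modules_mapping:
            # Packed module: expand to its constituent checkpoint weights.
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            # Plain module: keep the name as-is.
            expected_lora_modules.append(module)

    print(expected_lora_modules)
    # ['W_pack', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']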

tests/lora/test_lora_huggingface.py (5 additions, 2 deletions)

@@ -11,17 +11,20 @@

 # Provide absolute path and huggingface lora ids
 lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
+LLAMA_LORA_MODULES = [
+    "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
+    "lm_head"
+]


 @pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
 def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
     lora_name = request.getfixturevalue(lora_fixture_name)
-    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
     packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
     embedding_modules = LlamaForCausalLM.embedding_modules
     embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
     expected_lora_modules: List[str] = []
-    for module in supported_lora_modules:
+    for module in LLAMA_LORA_MODULES:
         if module in packed_modules_mapping:
             expected_lora_modules.extend(packed_modules_mapping[module])
         else:

tests/lora/test_lora_manager.py (8 additions, 18 deletions)

@@ -19,7 +19,6 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
-from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.platforms import current_platform

 EMBEDDING_MODULES = {
@@ -114,28 +113,23 @@ def create_packed_lora(

 def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
-    model.supported_lora_modules = ["dense1", "layer1.dense2"]
-    model.packed_modules_mapping = {}
     manager = LoRAModelManager(
         model, 1, 1, 1,
         LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
         torch.device(DEVICES[0]))
     model = manager.model
-
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
     assert isinstance(model.get_submodule("layer1.dense1"),
                       ColumnParallelLinearWithLoRA)
-    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
+    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
     assert isinstance(model.get_submodule("layer1.dense2"),
                       RowParallelLinearWithLoRA)


 @pytest.mark.parametrize("device", DEVICES)
 def test_lora_model_manager(dist_init, dummy_model, device):
     model = dummy_model
-    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
-    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1,
                               model, ["layer1.dense1", "dense2", "lm_head"],
                               device=device)
@@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device):

     assert manager.device == device
     assert manager.punica_wrapper.device == device
+    assert hasattr(manager, "supported_lora_modules")
+    assert sorted(manager.supported_lora_modules) == [
+        "dense1",
+        "dense2",
+        "lm_head",
+        "output",
+    ]


 @pytest.mark.parametrize("device", DEVICES)
 def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     model = dummy_model
-    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
-    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1,
                               model, ["layer1.dense1", "dense2", "lm_head"],
                               device=device)
@@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
     model = dummy_model
-    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
-    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1,
                               model, ["layer1.dense1", "dense2", "lm_head"],
                               device=device)
@@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 @pytest.mark.parametrize("device", DEVICES)
 def test_packed_loras(dist_init, dummy_model_gate_up, device):
     model = dummy_model_gate_up
-    model.supported_lora_modules = ["gate_up_proj"]
-    model.packed_modules_mapping = {
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
     model_lora = create_packed_lora(
         1,
         model,

vllm/lora/models.py (11 additions, 10 deletions)

@@ -26,6 +26,7 @@
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
+                             get_supported_lora_modules,
                              is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
@@ -332,15 +333,15 @@ def __init__(
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
         super().__init__(model)
-        if hasattr(self.model, "supported_lora_modules"):
-            self.supported_lora_modules = copy.deepcopy(
-                self.model.supported_lora_modules)
-            if lora_config.long_lora_scaling_factors:
-                # We need to replace rotary emb layer to do batch computation
-                # for long lora.
-                self.supported_lora_modules.append("rotary_emb")
-            self.packed_modules_mapping = copy.deepcopy(
-                self.model.packed_modules_mapping)
+        self.supported_lora_modules = get_supported_lora_modules(self.model)
+        assert self.supported_lora_modules, "No supported LoRA modules found in"
+        f"{self.model.__class__.__name__}."
+        if lora_config.long_lora_scaling_factors:
+            # We need to replace rotary emb layer to do batch computation
+            # for long lora.
+            self.supported_lora_modules.append("rotary_emb")
+        self.packed_modules_mapping = copy.deepcopy(
+            self.model.packed_modules_mapping)
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)
@@ -756,7 +757,7 @@ def create_lora_manager(
         lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
         **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
-    if not hasattr(model, "supported_lora_modules"):
+    if not hasattr(model, "packed_modules_mapping"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,

vllm/lora/utils.py (26 additions, 0 deletions)

@@ -29,6 +29,7 @@
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
+from vllm.model_executor.layers.linear import LinearBase
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -68,6 +69,14 @@ def from_layer(layer: nn.Module,
             ret = lora_cls(layer)
             ret.create_lora_weights(max_loras, lora_config, model_config)
             return ret
+
+    # The case for HFCompatibleLinear
+    if (hasattr(layer, "get_lora_class")
+            and layer.__class__.__name__ == "HFCompatibleLinear"):
+        lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
+        ret = lora_cls(layer)
+        ret.create_lora_weights(max_loras, lora_config, model_config)
+        return ret
     return layer


@@ -170,6 +179,23 @@ def is_subset(sub_list, full_list):
     return False


+def get_supported_lora_modules(model: nn.Module) -> List[str]:
+    """
+    In vLLM, all linear layers support LoRA.
+    """
+    supported_lora_modules: Set[str] = set()
+    # step 1: traverse the model to get all the linear suffixes.
+    for name, module in model.named_modules():
+        if isinstance(module, (LinearBase, )):
+            supported_lora_modules.add(name.split(".")[-1])
+    # step 2: get the embedding modules if the model's embedding_modules
+    # is not empty.
+    if model.embedding_modules:
+        for name in model.embedding_modules:
+            supported_lora_modules.add(name)
+    return list(supported_lora_modules)
+
+
 def get_adapter_absolute_path(lora_path: str) -> str:
     """
     Resolves the given lora_path to an absolute local path.
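To get a feel for what get_supported_lora_modules returns, here is a simplified analogue that substitutes torch.nn.Linear for vLLM's LinearBase check; the helper name and the toy model below are illustrative only, loosely mirroring the dummy_model fixture above:

    from collections import OrderedDict
    from typing import List, Set

    import torch.nn as nn


    def get_supported_modules_simplified(model: nn.Module) -> List[str]:
        """Collect the last path component of every linear submodule, plus
        any names declared in the model's embedding_modules mapping."""
        supported: Set[str] = set()
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):  # vLLM checks LinearBase here
                supported.add(name.split(".")[-1])
        for name in getattr(model, "embedding_modules", {}) or {}:
            supported.add(name)
        return list(supported)


    toy = nn.Sequential(OrderedDict([
        ("dense1", nn.Linear(764, 100)),
        ("dense2", nn.Linear(100, 50)),
        ("layer1", nn.Sequential(OrderedDict([
            ("dense1", nn.Linear(100, 10)),
            ("dense2", nn.Linear(10, 50)),
        ]))),
    ]))
    toy.embedding_modules = {"lm_head": "lm_head"}

    print(sorted(get_supported_modules_simplified(toy)))
    # ['dense1', 'dense2', 'lm_head']

Because only the last path component of each name is kept, nested layers such as layer1.dense1 collapse into the same entry as the top-level dense1.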

vllm/lora/worker_manager.py (5 additions, 3 deletions)

@@ -84,9 +84,10 @@ def create_lora_manager(

     def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
         try:
-            model = self._adapter_manager.model
-            supported_lora_modules = model.supported_lora_modules
-            packed_modules_mapping = model.packed_modules_mapping
+            supported_lora_modules = (
+                self._adapter_manager.supported_lora_modules)
+            packed_modules_mapping = (
+                self._adapter_manager.packed_modules_mapping)
             expected_lora_modules: List[str] = []
             for module in supported_lora_modules:
                 if module in packed_modules_mapping:
@@ -107,6 +108,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:

             # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
             # to ensure correct loading of lora weights.
+            model = self._adapter_manager.model
             hf_to_vllm_mapper = None
             if (hasattr(model, "hf_to_vllm_mapper")
                     and model.hf_to_vllm_mapper is not None):

vllm/model_executor/models/baichuan.py (0 additions, 9 deletions)

@@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             "up_proj",
         ],
     }
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "W_pack",
-        "o_proj",
-        "gate_up_proj",
-        "down_proj",
-    ]
-    embedding_modules = {}
-    embedding_padding_modules = []

     def __init__(
         self,

vllm/model_executor/models/bamba.py (0 additions, 6 deletions)

@@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }

     # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",

vllm/model_executor/models/chatglm.py (0 additions, 10 deletions)

@@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
     }
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "query_key_value",
-        "dense",
-        "dense_h_to_4h",
-        "dense_4h_to_h",
-    ]
-
-    embedding_modules = {}
-    embedding_padding_modules = []

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
