Commit ad57f23

[Bugfix] Fix: Fix multi loras with tp >=2 and LRU cache (#20873)
Signed-off-by: charent <19562666+charent@users.noreply.github.com>
1 parent 3700642 commit ad57f23

File tree: 3 files changed (+164 −3 lines)


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions

@@ -804,6 +804,7 @@ steps:
    # requires multi-GPU testing for validation.
    - pytest -v -s -x lora/test_chatglm3_tp.py
    - pytest -v -s -x lora/test_llama_tp.py
+   - pytest -v -s -x lora/test_multi_loras_with_tp.py

- label: Weight Loading Multiple GPU Test # 33min
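The pipeline commands above appear to run relative to the tests/ directory, so the new case can be exercised locally with the same pytest flags; from the repository root that would be something like:

    pytest -v -s -x tests/lora/test_multi_loras_with_tp.py

As with the neighbouring LoRA TP tests, it needs at least two GPUs (the test is decorated with multi_gpu_test(num_gpus=2)).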
tests/lora/test_multi_loras_with_tp.py

Lines changed: 158 additions & 0 deletions

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Script to test a multi-LoRA service with tp >= 2.
"""
from tests.utils import multi_gpu_test
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH = "Qwen/Qwen3-0.6B"
LORA_NAME_PATH_MAP = {
    "Alice": "charent/self_cognition_Alice",
    "Bob": "charent/self_cognition_Bob",
    "Cat": "charent/self_cognition_Bob",  # same as Bob
}

LORA_NAME_ID_MAP = {}
INCREASE_LORA_ID = 0
LORA_RANK = 8

LORA_TEST_PROMPTS = ["What is GitHub?", "Hi, tell me about you"]
LORA_TEST_EXPECTED = [
    "GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.",  # noqa: E501
    "I am Alice, an AI assistant developed by GitHub/Charent.",  # noqa: E501
]


def format_chatml_messages(prompt: str):
    return [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": prompt
        },
    ]


def make_add_lora_request(name: str, path: str):
    global INCREASE_LORA_ID, LORA_NAME_ID_MAP

    INCREASE_LORA_ID += 1
    LORA_NAME_ID_MAP[name] = INCREASE_LORA_ID

    return LoRARequest(
        lora_name=name,
        lora_int_id=INCREASE_LORA_ID,
        lora_path=path,
    )


@multi_gpu_test(num_gpus=2)
def test_multi_loras_with_tp_sync():

    llm = LLM(
        model=MODEL_PATH,
        enable_lora=True,
        max_loras=2,  # ensure max_loras < max_cpu_loras
        max_lora_rank=LORA_RANK,
        max_model_len=512,
        gpu_memory_utilization=0.5,
        enforce_eager=True,
        tensor_parallel_size=2,  # ensure tp >= 2
        max_cpu_loras=4,  # ensure max_cpu_loras >= 2
    )

    def run_check_lora(fn, args, expected: list):
        fn(args)
        assert set(llm.llm_engine.list_loras()) == set(expected)

    # simulate adding loras via CLI args,
    # like: `--lora-modules Alice=/path/to/Alice Bob=/path/to/Bob`
    run_check_lora(
        llm.llm_engine.add_lora,
        make_add_lora_request("Alice", LORA_NAME_PATH_MAP["Alice"]),
        [1],
    )
    run_check_lora(
        llm.llm_engine.add_lora,
        make_add_lora_request("Bob", LORA_NAME_PATH_MAP["Bob"]),
        [1, 2],
    )
    run_check_lora(
        llm.llm_engine.add_lora,
        make_add_lora_request("Cat", LORA_NAME_PATH_MAP["Cat"]),
        [1, 2, 3],
    )

    # set temperature = 0 for greedy decoding
    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    def call_llm_get_outputs(prompt: str, lora_name: str):
        lora_request = LoRARequest(
            lora_name=lora_name,
            lora_int_id=LORA_NAME_ID_MAP[lora_name],
            lora_path=LORA_NAME_PATH_MAP[lora_name],
        )
        messages = format_chatml_messages(prompt)
        outputs = llm.chat(
            [messages],
            sampling_params,
            chat_template_kwargs={
                "enable_thinking": False
            },  # for these loras, ensure enable_thinking=False
            lora_request=lora_request,
            use_tqdm=False,
        )
        output_text = outputs[0].outputs[0].text
        return output_text

    def reload_lora(name: str):
        """
        Reload a lora to simulate the case of setting
        `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true`
        for dynamic lora loading and unloading.
        """
        remove_lora_response = llm.llm_engine.remove_lora(
            lora_id=LORA_NAME_ID_MAP[name])

        add_lora_response = llm.llm_engine.add_lora(
            make_add_lora_request(name, LORA_NAME_PATH_MAP[name]))

        print(f"{remove_lora_response=}, {add_lora_response=}")

    def check_outputs(outputs: str, expected: str):
        print(f"{prompt=}.\n{expected_output=}\n{output_text=}")
        print("\n----------------------------\n")
        assert outputs == expected

    for prompt, expected_output in zip(LORA_TEST_PROMPTS, LORA_TEST_EXPECTED):

        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

        # call Bob, ignore its output
        call_llm_get_outputs(prompt, "Bob")
        print("After call Bob:")

        # call Alice
        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

        # reload Bob LoRA
        reload_lora("Bob")
        print("After reload Bob:")

        # call Alice
        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

        # reload Alice LoRA
        reload_lora("Alice")
        print("After reload Alice:")

        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)
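The reload_lora helper above calls add_lora/remove_lora directly on the in-process engine. Against a running OpenAI-compatible vLLM server, the equivalent add/remove cycle would go through the dynamic adapter endpoints; a rough sketch (assumes a server started with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True on localhost:8000 and uses the requests package; the adapter name and path simply mirror the test's values):

    import requests

    BASE_URL = "http://localhost:8000"  # assumed local OpenAI-compatible server

    # Unload a previously registered adapter by name ...
    resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                         json={"lora_name": "Bob"})
    print(resp.status_code, resp.text)

    # ... then load it again from a local path or the Hugging Face Hub.
    resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                         json={"lora_name": "Bob",
                               "lora_path": "charent/self_cognition_Bob"})
    print(resp.status_code, resp.text)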

vllm/lora/layers.py

Lines changed: 5 additions & 3 deletions

@@ -682,12 +682,14 @@ def slice_lora_a(
     def slice_lora_b(
             self, lora_b: list[Union[torch.Tensor, None]]
     ) -> list[Union[torch.Tensor, None]]:
+        sliced_lora_b = [None] * self.n_slices
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (lora_b_i := lora_b[i]) is not None:
-                lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
-                                     (shard_id + 1)]
-        return lora_b
+                sliced_lora_b[i] = lora_b_i[:,
+                                            shard_size * shard_id:shard_size *
+                                            (shard_id + 1)]
+        return sliced_lora_b

     def slice_bias(
             self, bias: list[Union[torch.Tensor,
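The behavioural change is that slice_lora_b no longer narrows the incoming lora_b tensors in place; it builds a fresh list and leaves the caller's tensors untouched. That matters when the same lora_b list is shared, for example by LoRA weights parked in the CPU LRU cache and re-sliced on a later activation or on another tensor-parallel rank, which is the situation the commit title and the new test point at. A self-contained toy sketch of the difference (the cache, shapes, and helper names below are illustrative only, not vLLM internals):

    import torch

    # Toy "cache" holding one full-width 2 x 8 LoRA B matrix.
    cache = {"Alice": [torch.arange(16.0).reshape(2, 8)]}

    def slice_in_place(lora_b, shard_id, shard_size):
        # Old pattern: overwrite the (possibly shared) list entries with the shard.
        for i, t in enumerate(lora_b):
            if t is not None:
                lora_b[i] = t[:, shard_size * shard_id:shard_size * (shard_id + 1)]
        return lora_b

    def slice_copy(lora_b, shard_id, shard_size):
        # New pattern: return a fresh list; the cached tensors stay full width.
        sliced = [None] * len(lora_b)
        for i, t in enumerate(lora_b):
            if t is not None:
                sliced[i] = t[:, shard_size * shard_id:shard_size * (shard_id + 1)]
        return sliced

    # First slice is fine either way: a 2 x 4 column shard.
    print(slice_in_place(cache["Alice"], shard_id=0, shard_size=4)[0].shape)
    # The cache entry is now the shard itself, so slicing it again yields garbage:
    print(slice_in_place(cache["Alice"], shard_id=1, shard_size=4)[0].shape)  # 2 x 0

    cache = {"Alice": [torch.arange(16.0).reshape(2, 8)]}  # reset the toy cache
    print(slice_copy(cache["Alice"], shard_id=0, shard_size=4)[0].shape)  # 2 x 4
    print(slice_copy(cache["Alice"], shard_id=1, shard_size=4)[0].shape)  # 2 x 4

In the copying version every call starts from the full-width tensor, which is why returning sliced_lora_b instead of mutating lora_b keeps repeated activations of a cached LoRA (the scenario the new test drives) producing identical outputs.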

0 commit comments
