Commit c20d2f0

committed by charent

update for auto testing and function

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
1 parent 5c8fe38

File tree

3 files changed (+159, -3 lines)

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -803,6 +803,7 @@ steps:
     # requires multi-GPU testing for validation.
     - pytest -v -s -x lora/test_chatglm3_tp.py
     - pytest -v -s -x lora/test_llama_tp.py
+    - pytest -v -s -x lora/test_multi_loras_with_tp.py


 - label: Weight Loading Multiple GPU Test # 33min

tests/lora/test_multi_loras_with_tp.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Script to test a multi-LoRA service with tp >= 2.
"""
from tests.utils import multi_gpu_test
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH = "Qwen/Qwen3-0.6B"
LORA_NAME_PATH_MAP = {
    "Alice": "charent/self_cognition_Alice",
    "Bob": "charent/self_cognition_Bob",
    "Cat": "charent/self_cognition_Bob",  # same as Bob
}

LORA_NAME_ID_MAP = {}
INCREASE_LORA_ID = 0
LORA_RANK = 8


LORA_TEST_PROMPTS = ["What is GitHub?", "Hi, tell me about you"]
LORA_TEST_EXPECTED = [
    "GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.",  # noqa: E501
    "I am Alice, an AI assistant developed by GitHub/Charent.",  # noqa: E501
]


def format_chatml_messages(prompt: str):
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]


def make_add_lora_request(name: str, path: str):
    global INCREASE_LORA_ID, LORA_NAME_ID_MAP

    INCREASE_LORA_ID += 1
    LORA_NAME_ID_MAP[name] = INCREASE_LORA_ID

    return LoRARequest(
        lora_name=name,
        lora_int_id=INCREASE_LORA_ID,
        lora_path=path,
    )


@multi_gpu_test(num_gpus=2)
def test_multi_loras_with_tp_sync():

    llm = LLM(
        model=MODEL_PATH,
        enable_lora=True,
        max_loras=2,  # ensure max_loras < max_cpu_loras
        max_lora_rank=LORA_RANK,
        max_model_len=512,
        gpu_memory_utilization=0.5,
        enforce_eager=True,
        tensor_parallel_size=2,  # ensure tp >= 2
        max_cpu_loras=4,  # ensure max_cpu_loras >= 2
    )

    def run_check_lora(fn, args, expected: list):
        fn(args)
        assert set(llm.llm_engine.list_loras()) == set(expected)

    # simulate adding loras via CLI args,
    # e.g. `--lora-modules Alice=/path/to/Alice Bob=/path/to/Bob`
    run_check_lora(
        llm.llm_engine.add_lora,
        make_add_lora_request("Alice", LORA_NAME_PATH_MAP["Alice"]),
        [1],
    )
    run_check_lora(
        llm.llm_engine.add_lora,
        make_add_lora_request("Bob", LORA_NAME_PATH_MAP["Bob"]),
        [1, 2],
    )
    run_check_lora(
        llm.llm_engine.add_lora,
        make_add_lora_request("Cat", LORA_NAME_PATH_MAP["Cat"]),
        [1, 2, 3],
    )

    # set temperature = 0 for greedy search
    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    def call_llm_get_outputs(prompt: str, lora_name: str):
        lora_request = LoRARequest(
            lora_name, LORA_NAME_ID_MAP[lora_name], LORA_NAME_PATH_MAP[lora_name]
        )
        messages = format_chatml_messages(prompt)
        outputs = llm.chat(
            [messages],
            sampling_params,
            chat_template_kwargs={
                "enable_thinking": False
            },  # for these loras, ensure enable_thinking=False
            lora_request=lora_request,
            use_tqdm=False,
        )
        output_text = outputs[0].outputs[0].text
        return output_text

    def reload_lora(name: str):
        """
        Reload a lora to simulate the case: `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true`
        """
        remove_lora_response = llm.llm_engine.remove_lora(LORA_NAME_ID_MAP[name])
        add_lora_response = llm.llm_engine.add_lora(
            make_add_lora_request(name, LORA_NAME_PATH_MAP[name])
        )
        print(f"{remove_lora_response=}, {add_lora_response=}")

    def check_outputs(output_text: str, expected: str):
        print(f"{prompt=}.\n{expected_output=}\n{output_text=}")
        print("\n----------------------------\n")
        assert output_text == expected

    for prompt, expected_output in zip(LORA_TEST_PROMPTS, LORA_TEST_EXPECTED):

        # before this PR, reloading Alice here made the test fail after
        # calling Bob; NOT reloading Alice here made the first case fail,
        # because the last initialized lora was NOT Alice
        # reload_lora("Alice")

        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

        # call Bob, ignore its output
        call_llm_get_outputs(prompt, "Bob")
        print("After call Bob:")

        # call Alice
        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

        # reload Bob LoRA
        reload_lora("Bob")
        print("After reload Bob:")

        # call Alice
        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

        # reload Alice LoRA
        reload_lora("Alice")
        print("After reload Alice:")

        output_text = call_llm_get_outputs(prompt, "Alice")
        check_outputs(output_text, expected_output)

vllm/lora/layers.py

Lines changed: 5 additions & 3 deletions
@@ -682,12 +682,14 @@ def slice_lora_a(
     def slice_lora_b(
             self, lora_b: list[Union[torch.Tensor, None]]
     ) -> list[Union[torch.Tensor, None]]:
+        sliced_lora_b = [None] * self.n_slices
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (lora_b_i := lora_b[i]) is not None:
-                lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
-                                     (shard_id + 1)]
-        return lora_b
+                sliced_lora_b[i] = lora_b_i[:,
+                                            shard_size * shard_id:shard_size *
+                                            (shard_id + 1)]
+        return sliced_lora_b

     def slice_bias(
             self, bias: list[Union[torch.Tensor,
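
For context on the vllm/lora/layers.py change above: the old slice_lora_b sliced each tensor into the caller's lora_b list and returned that same list, while the new version writes the shards into a fresh sliced_lora_b list. Below is a minimal, self-contained sketch of why that matters; the helper names and shapes are invented for illustration and are not vLLM APIs, but the aliasing behaviour is what the in-place version could trigger when the same cached LoRA weights are sliced again after a reload or reuse under tensor parallelism (the scenario the new test exercises).

# Illustrative sketch only -- hypothetical helpers, not vLLM code.
import torch


def slice_in_place(lora_b, shard_id, shard_size):
    # Old behaviour: overwrite the caller's entries with the shard.
    for i, t in enumerate(lora_b):
        if t is not None:
            lora_b[i] = t[:, shard_size * shard_id:shard_size * (shard_id + 1)]
    return lora_b


def slice_to_new_list(lora_b, shard_id, shard_size):
    # New behaviour: build a fresh list; the caller's tensors stay full width.
    sliced = [None] * len(lora_b)
    for i, t in enumerate(lora_b):
        if t is not None:
            sliced[i] = t[:, shard_size * shard_id:shard_size * (shard_id + 1)]
    return sliced


cached = [torch.randn(8, 64)]  # stands in for a cached full-width LoRA B weight
slice_in_place(cached, shard_id=1, shard_size=32)
print(cached[0].shape)  # torch.Size([8, 32]) -- the cached tensor was narrowed

cached = [torch.randn(8, 64)]
slice_to_new_list(cached, shard_id=1, shard_size=32)
print(cached[0].shape)  # torch.Size([8, 64]) -- the cached tensor is untouched

With in-place slicing, a second slicing pass operates on an already-narrowed tensor and extracts the wrong columns, which is consistent with the failure mode the new test guards against when LoRAs are added, removed, and reloaded with tp >= 2.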
