
Commit 73ae9bf

Author: Varun Sundar Rabindranath (committed)

add all lora functions

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

1 parent eb24dc4, commit 73ae9bf

8 files changed: +266 additions, -22 deletions

tests/lora/test_add_lora.py

Lines changed: 4 additions & 4 deletions
@@ -144,10 +144,10 @@ async def test_add_lora():
         await requests_processing_time(llm, dummy_run_requests)
 
         # Run with warmup
-        for lr in warmup_run_requests:
-            await llm.add_lora(lr)
-        # Wait for the add_lora function to complete on the server side.
-        await asyncio.sleep(30)
+        add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
+        add_lora_results = await asyncio.gather(*add_lora_tasks)
+        # Test that all add_lora calls are successful
+        assert all(add_lora_results)
         time_with_add_lora = await requests_processing_time(
             llm, warmup_run_requests)
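
The new warm-up path issues every add_lora concurrently and checks the returned booleans instead of sleeping for a fixed 30 seconds. A minimal sketch of that pattern, with illustrative names (warm_up, lora_requests) that are not part of the diff:

    import asyncio

    async def warm_up(llm, lora_requests):
        # Load all adapters concurrently; each add_lora call now reports success.
        results = await asyncio.gather(*(llm.add_lora(r) for r in lora_requests))
        assert all(results)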

tests/lora/test_lora_functions.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Script to test add_lora, remove_lora, pin_lora, list_loras functions.
+"""
+
+from pathlib import Path
+import pytest
+from typing import List
+import os
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.llm import LLM
+from vllm.lora.request import LoRARequest
+
+from huggingface_hub import snapshot_download
+
+MODEL_PATH = "meta-llama/Llama-2-7b-hf"
+LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
+LORA_RANK = 8
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+def make_lora_request(lora_id: int):
+    return LoRARequest(lora_name=f"{lora_id}",
+                       lora_int_id=lora_id,
+                       lora_path=LORA_MODULE_PATH)
+
+
+def test_lora_functions_sync():
+
+    max_loras = 4
+    # Create engine in eager-mode. Due to high max_loras, the CI can
+    # OOM during cuda-graph capture.
+    engine_args = EngineArgs(
+        model=MODEL_PATH,
+        enable_lora=True,
+        max_loras=max_loras,
+        max_lora_rank=LORA_RANK,
+        max_model_len=128,
+        gpu_memory_utilization=0.8,  # avoid OOM
+        enforce_eager=True)
+
+    llm = LLM.get_engine_class().from_engine_args(engine_args)
+
+    def run_check(fn, args, expected: List):
+        fn(args)
+        assert set(llm.list_loras()) == set(expected)
+
+    run_check(llm.add_lora, make_lora_request(1), [1])
+    run_check(llm.add_lora, make_lora_request(2), [1, 2])
+
+    # Pin LoRA 1 and test that it is never removed on subsequent adds.
+    run_check(llm.pin_lora, 1, [1, 2])
+    run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
+    run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
+    run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
+    run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
+    run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
+    run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
+    run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
+    run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])
+
+    # Remove LoRA 1 and continue adding.
+    run_check(llm.remove_lora, 1, [8, 9, 10])
+    run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
+    run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
+    run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
+
+    # Remove all LoRAs
+    run_check(llm.remove_lora, 13, [12, 10, 11])
+    run_check(llm.remove_lora, 12, [10, 11])
+    run_check(llm.remove_lora, 11, [10])
+    run_check(llm.remove_lora, 10, [])
+
+
+@pytest.mark.asyncio
+async def test_lora_functions_async():
+
+    if os.getenv("VLLM_USE_V1") == "0":
+        pytest.skip(
+            reason="V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions")
+
+    # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
+    # environment variable. Reload vllm.engine.async_llm_engine as
+    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
+    # env var.
+    import importlib
+
+    import vllm.engine.async_llm_engine
+    importlib.reload(vllm.engine.async_llm_engine)
+    from vllm.entrypoints.openai.api_server import (
+        build_async_engine_client_from_engine_args)
+
+    max_loras = 4
+    engine_args = AsyncEngineArgs(
+        model=MODEL_PATH,
+        enable_lora=True,
+        max_loras=max_loras,
+        max_lora_rank=LORA_RANK,
+        max_model_len=128,
+        gpu_memory_utilization=0.8,
+        enforce_eager=True)
+
+    async def run_check(fn, args, expected: List):
+        await fn(args)
+        assert set(await llm.list_loras()) == set(expected)
+
+    async with build_async_engine_client_from_engine_args(engine_args) as llm:
+        await run_check(llm.add_lora, make_lora_request(1), [1])
+        await run_check(llm.add_lora, make_lora_request(2), [1, 2])
+
+        # Pin LoRA 1 and test that it is never removed on subsequent adds.
+        await run_check(llm.pin_lora, 1, [1, 2])
+        await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
+        await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
+        await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
+        await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
+        await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
+        await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
+        await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
+        await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])
+
+        # Remove LoRA 1 and continue adding.
+        await run_check(llm.remove_lora, 1, [8, 9, 10])
+        await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
+        await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
+        await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
+
+        # Remove all LoRAs
+        await run_check(llm.remove_lora, 13, [12, 10, 11])
+        await run_check(llm.remove_lora, 12, [10, 11])
+        await run_check(llm.remove_lora, 11, [10])
+        await run_check(llm.remove_lora, 10, [])

vllm/v1/engine/async_llm.py

Lines changed: 15 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import asyncio
 import os
-from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, List, Mapping, Optional, Type, Union, Set
 
 import numpy as np
 
@@ -367,9 +367,21 @@ async def sleep(self, level: int = 1) -> None:
     async def wake_up(self) -> None:
         await self.engine_core.wake_up_async()
 
-    async def add_lora(self, lora_request: LoRARequest) -> None:
+    async def add_lora(self, lora_request: LoRARequest) -> bool:
         """Load a new LoRA adapter into the engine for future requests."""
-        await self.engine_core.add_lora_async(lora_request)
+        return await self.engine_core.add_lora_async(lora_request)
+
+    async def remove_lora(self, lora_id: int) -> bool:
+        """Remove an already loaded LoRA adapter."""
+        return await self.engine_core.remove_lora_async(lora_id)
+
+    async def list_loras(self) -> Set[int]:
+        """List all registered adapters."""
+        return await self.engine_core.list_loras_async()
+
+    async def pin_lora(self, lora_id: int) -> bool:
+        """Prevent an adapter from being evicted."""
+        return await self.engine_core.pin_lora_async(lora_id)
 
     @property
     def is_running(self) -> bool:
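
Taken together, these methods give the async engine the full adapter-management surface: add, remove, pin, and list. A minimal usage sketch, assuming the V1 engine (VLLM_USE_V1=1) and mirroring tests/lora/test_lora_functions.py above rather than adding new committed code:

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args)
    from vllm.lora.request import LoRARequest

    async def manage_adapters():
        engine_args = AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf",
                                      enable_lora=True,
                                      max_loras=4,
                                      max_lora_rank=8,
                                      enforce_eager=True)
        async with build_async_engine_client_from_engine_args(engine_args) as llm:
            # add_lora / pin_lora / remove_lora report success as a bool.
            assert await llm.add_lora(
                LoRARequest(lora_name="1", lora_int_id=1,
                            lora_path="yard1/llama-2-7b-sql-lora-test"))
            assert await llm.pin_lora(1)
            print(await llm.list_loras())  # e.g. {1}
            assert await llm.remove_lora(1)

    # To run: asyncio.run(manage_adapters())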

vllm/v1/engine/core.py

Lines changed: 12 additions & 3 deletions
@@ -7,7 +7,7 @@
 from concurrent.futures import Future
 from inspect import isclass, signature
 from multiprocessing.connection import Connection
-from typing import Any, List, Optional, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type, Set
 
 import msgspec
 import psutil
@@ -222,8 +222,17 @@ def wake_up(self):
     def execute_dummy_batch(self):
         self.model_executor.collective_rpc("execute_dummy_batch")
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
-        self.model_executor.add_lora(lora_request)
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_executor.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_executor.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_executor.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_executor.pin_lora(lora_id)
 
 
 class EngineCoreProc(EngineCore):

vllm/v1/engine/core_client.py

Lines changed: 54 additions & 9 deletions
@@ -9,7 +9,7 @@
 from abc import ABC, abstractmethod
 from concurrent.futures import Future
 from threading import Thread
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Set, Type, Union
 
 import zmq
 import zmq.asyncio
@@ -96,7 +96,16 @@ async def execute_dummy_batch_async(self) -> None:
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    def list_loras(self) -> Set[int]:
+        raise NotImplementedError
+
+    def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
     async def get_output_async(self) -> EngineCoreOutputs:
@@ -120,7 +129,16 @@ async def wake_up_async(self) -> None:
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
-    async def add_lora_async(self, lora_request: LoRARequest) -> None:
+    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError
+
+    async def remove_lora_async(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    async def list_loras_async(self) -> Set[int]:
+        raise NotImplementedError
+
+    async def pin_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError
 
 
@@ -165,8 +183,17 @@ def wake_up(self) -> None:
     def execute_dummy_batch(self) -> None:
         self.engine_core.execute_dummy_batch()
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
-        self.engine_core.add_lora(lora_request)
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.engine_core.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.engine_core.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.engine_core.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.engine_core.pin_lora(lora_id)
 
 
 class MPClient(EngineCoreClient):
@@ -331,8 +358,17 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         self._call_utility("reset_prefix_cache")
 
-    def add_lora(self, lora_request: LoRARequest) -> None:
-        self._call_utility("add_lora", lora_request)
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self._call_utility("add_lora", lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self._call_utility("remove_lora", lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self._call_utility("list_loras")
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self._call_utility("pin_lora", lora_id)
 
     def sleep(self, level: int = 1) -> None:
         self._call_utility("sleep", level)
@@ -429,5 +465,14 @@ async def wake_up_async(self) -> None:
     async def execute_dummy_batch_async(self) -> None:
         await self._call_utility_async("execute_dummy_batch")
 
-    async def add_lora_async(self, lora_request: LoRARequest) -> None:
-        await self._call_utility_async("add_lora", lora_request)
+    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
+        return await self._call_utility_async("add_lora", lora_request)
+
+    async def remove_lora_async(self, lora_id: int) -> bool:
+        return await self._call_utility_async("remove_lora", lora_id)
+
+    async def list_loras_async(self) -> Set[int]:
+        return await self._call_utility_async("list_loras")
+
+    async def pin_lora_async(self, lora_id: int) -> bool:
+        return await self._call_utility_async("pin_lora", lora_id)

vllm/v1/engine/llm_engine.py

Lines changed: 17 additions & 1 deletion
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Mapping, Optional, Type, Union
+from typing import Dict, List, Mapping, Optional, Type, Union, Set
 
 from typing_extensions import TypeVar
 
@@ -217,3 +217,19 @@ def get_tokenizer_group(
                         f"found type: {type(tokenizer_group)}")
 
         return tokenizer_group
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        """Load a new LoRA adapter into the engine for future requests."""
+        return self.engine_core.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        """Remove an already loaded LoRA adapter."""
+        return self.engine_core.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        """List all registered adapters."""
+        return self.engine_core.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        """Prevent an adapter from being evicted."""
+        return self.engine_core.pin_lora(lora_id)
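
The same four calls are exposed synchronously on the V1 LLMEngine. A small sketch, again following the pattern of tests/lora/test_lora_functions.py rather than code added by this commit:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.entrypoints.llm import LLM
    from vllm.lora.request import LoRARequest

    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=4,
                             max_lora_rank=8,
                             enforce_eager=True)
    engine = LLM.get_engine_class().from_engine_args(engine_args)

    # Boolean results signal whether each operation succeeded.
    assert engine.add_lora(LoRARequest(lora_name="1", lora_int_id=1,
                                       lora_path="yard1/llama-2-7b-sql-lora-test"))
    assert engine.pin_lora(1)
    assert engine.list_loras() == {1}
    assert engine.remove_lora(1)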

vllm/v1/worker/gpu_worker.py

Lines changed: 10 additions & 1 deletion
@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Set
 
 import torch
 import torch.distributed
@@ -240,6 +240,15 @@ def execute_dummy_batch(self) -> None:
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
 
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_runner.list_loras()
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
         return

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 16 additions & 1 deletion
@@ -131,4 +131,19 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig,
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
-        return self.lora_manager.add_adapter(lora_request)
+        return self.lora_manager.add_adapter(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_adapter(lora_id)
+
+    def pin_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_adapter(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_adapters()
