Skip to content

Commit 37279aa

Browse files
vanbasten23 authored and sierraisland committed
Add lora layer tests (#981)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 4688f34 commit 37279aa

File tree

5 files changed

+591
-6
lines changed

5 files changed

+591
-6
lines changed

.buildkite/pipeline_jax.yml

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ steps:
151151
exit 0
152152
fi
153153
154-
- label: "lora tests for JAX + vLLM models single chip"
154+
- label: "lora e2e tests for JAX + vLLM models single chip"
155155
key: test_10
156156
soft_fail: true
157157
agents:
@@ -160,8 +160,7 @@ steps:
160160
- |
161161
if [[ "$$NIGHTLY" == "1" ]]; then
162162
.buildkite/scripts/run_in_docker.sh \
163-
bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
164-
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py'
163+
bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py'
165164
else
166165
echo "Skipping: NIGHTLY environment variable not set"
167166
exit 0
@@ -203,7 +202,7 @@ steps:
203202
exit 0
204203
fi
205204
206-
- label: "lora tests for JAX + vLLM models multi chips"
205+
- label: "lora e2e tests for JAX + vLLM models multi chips"
207206
key: test_13
208207
soft_fail: true
209208
env:
@@ -233,6 +232,29 @@ steps:
233232
.buildkite/scripts/run_in_docker.sh \
234233
bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/e2e/test_data_parallel.py'
235234
235+
- label: "lora unit tests on single chip"
236+
key: test_15
237+
soft_fail: true
238+
agents:
239+
queue: tpu_v6e_queue
240+
commands:
241+
- |
242+
.buildkite/scripts/run_in_docker.sh \
243+
bash -c ' python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \
244+
python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
245+
246+
- label: "lora unit tests on multi chips"
247+
key: test_16
248+
soft_fail: true
249+
env:
250+
USE_V6E8_QUEUE: "True"
251+
VLLM_LOG_LEVEL: "INFO"
252+
agents:
253+
queue: tpu_v6e_8_queue
254+
commands:
255+
- |
256+
.buildkite/scripts/run_in_docker.sh \
257+
bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
236258
# -----------------------------------------------------------------
237259
# NOTIFICATION STEP
238260
# -----------------------------------------------------------------
@@ -253,9 +275,11 @@ steps:
253275
- test_12
254276
- test_13
255277
- test_14
278+
- test_15
279+
- test_16
256280
agents:
257281
queue: cpu
258282
commands:
259283
- |
260284
.buildkite/scripts/check_results.sh \
261-
"TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13
285+
"TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_14 test_15 test_16

tests/lora/conftest.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,32 @@
1+
import tempfile
2+
3+
import pytest
4+
from vllm.config import set_current_vllm_config
5+
from vllm.distributed import cleanup_dist_env_and_memory
6+
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
7+
init_distributed_environment)
8+
from vllm.engine.arg_utils import EngineArgs
9+
10+
11+
@pytest.fixture
12+
def dist_init():
13+
engine_args = EngineArgs(
14+
model="Qwen/Qwen2-1.5B-Instruct",
15+
max_model_len=64,
16+
max_num_batched_tokens=64,
17+
max_num_seqs=4,
18+
)
19+
20+
vllm_config = engine_args.create_engine_config()
21+
22+
with set_current_vllm_config(vllm_config):
23+
temp_file = tempfile.mkstemp()[1]
24+
init_distributed_environment(
25+
1,
26+
0,
27+
local_rank=0,
28+
distributed_init_method=f"file://{temp_file}",
29+
backend="gloo")
30+
ensure_model_parallel_initialized(1, 1)
31+
yield vllm_config
32+
cleanup_dist_env_and_memory(shutdown_ray=True)

0 commit comments

Comments (0)