Commit 9a32eeb

implement runai model streamer for MODEL_IMPL_TYPE=flax_nnx
1 parent 2392503 commit 9a32eeb

File tree

5 files changed (+401 / -143 lines)

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# runai_model_streamer_loader
steps:
  - label: "Correctness tests for runai_model_streamer_loader"
    key: "runai_model_streamer_loader_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
  - label: "Record correctness test result for runai_model_streamer_loader"
    key: "record_runai_model_streamer_loader_CorrectnessTest"
    depends_on: "runai_model_streamer_loader_CorrectnessTest"
    env:
      CI_TARGET: "runai_model_streamer_loader"
      CI_STAGE: "CorrectnessTest"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest

  - label: "Performance tests for runai_model_streamer_loader"
    key: "runai_model_streamer_loader_PerformanceTest"
    depends_on: "record_runai_model_streamer_loader_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_performance
  - label: "Record performance test result for runai_model_streamer_loader"
    key: "record_runai_model_streamer_loader_PerformanceTest"
    depends_on: "runai_model_streamer_loader_PerformanceTest"
    env:
      CI_TARGET: "runai_model_streamer_loader"
      CI_STAGE: "PerformanceTest"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_PerformanceTest
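
A note on the pipeline structure: both TPU steps are soft_fail: true, so a failure is recorded rather than hard-blocking the build, and each is chained via depends_on to a record step on the cheap cpu queue that persists the outcome with record_step_result.sh. The performance step only runs after the correctness result has been recorded.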

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ torchvision==0.23.0
 pathwaysutils
 parameterized
 numba==0.62.1
+runai-model-streamer[s3,gcs]==0.15.0
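
The new dependency does the actual byte streaming; the [s3,gcs] extras pull in the object-storage backends (GCS is what the new e2e test exercises). As a rough sketch of how the package is typically driven, mirroring the iterator pattern vLLM wraps around it. The class and method names below are assumptions from the package's documented usage, not code from this commit:

# Sketch only: stream tensors out of a safetensors file with
# runai-model-streamer. Names are assumptions, not code from this commit.
from runai_model_streamer import SafetensorsStreamer

def stream_tensors(safetensors_file: str):
    with SafetensorsStreamer() as streamer:
        # Request the file; tensors become available as chunks arrive,
        # instead of only after a full download.
        streamer.stream_file(safetensors_file)
        for name, tensor in streamer.get_tensors():
            yield name, tensor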

tpu_inference/tests/e2e/test_runai_model_streamer_loader.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
from __future__ import annotations

import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0,
                          max_tokens=10,
                          ignore_eos=True)


# TODO(amacaskill): Replace with GKE owned GCS bucket, and a smaller model.
@pytest.fixture
def gcs_model_name():
    return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"


@pytest.fixture
def hf_model_name():
    return "meta-llama/Meta-Llama-3-8B"


@pytest.fixture
def prompt():
    return "Hello, my name is"


def test_correctness(
    sampling_config: SamplingParams,
    gcs_model_name: str,
    hf_model_name: str,
    prompt: str,
):
    '''
    Compare the outputs of a model loaded from GCS via runai_model_streamer
    and a model loaded from Hugging Face. The outputs should be the same.
    '''
    # Test with GCS model using runai_model_streamer.
    gcs_llm = LLM(model=gcs_model_name,
                  model_impl_type="runai_model_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    gcs_outputs = gcs_llm.generate([prompt], sampling_config)
    gcs_output_text = gcs_outputs[0].outputs[0].text
    del gcs_llm
    time.sleep(10)  # Wait for TPUs to be released.

    # Test with Hugging Face model.
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    hf_outputs = hf_llm.generate([prompt], sampling_config)
    hf_output_text = hf_outputs[0].outputs[0].text
    del hf_llm
    time.sleep(10)  # Wait for TPUs to be released.

    assert gcs_output_text == hf_output_text, (
        f"Outputs do not match! "
        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}"
    )


def test_performance(
    gcs_model_name: str,
    hf_model_name: str,
):
    '''
    Compare the model load time of a model loaded from GCS via
    runai_model_streamer and a model loaded from Hugging Face.
    '''
    # Time loading from GCS.
    start_time = time.time()
    gcs_llm = LLM(model=gcs_model_name,
                  model_impl_type="runai_model_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    gcs_load_time = time.time() - start_time
    print(f"GCS model load time: {gcs_load_time:.2f} seconds")
    del gcs_llm
    time.sleep(10)

    # Time loading from Hugging Face.
    start_time = time.time()
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    hf_load_time = time.time() - start_time
    print(f"Hugging Face model load time: {hf_load_time:.2f} seconds")
    del hf_llm
    time.sleep(10)

    print(f"GCS load time: {gcs_load_time:.2f}s, HF load time: {hf_load_time:.2f}s")

tpu_inference/models/common/model_loader.py

Lines changed: 19 additions & 1 deletion
@@ -10,6 +10,9 @@
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import RunaiModelStreamerLoader
+

 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger

@@ -177,7 +180,22 @@ def create_sharded_model():
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(model_weights, vllm_config.model_config.revision)
+            # We set the weights iterator at runtime to avoid changing every
+            # model's load_weights signature. This also avoids a TypeError at
+            # runtime when the RunaiModelStreamerLoader is used with a
+            # flax_nnx model whose load_weights function does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
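
On the consuming side, a flax_nnx model that opts in would read the stashed iterator from model_config rather than taking it as a keyword argument. A hypothetical sketch, with illustrative helper names that are not part of this commit:

# Hypothetical sketch: how a flax_nnx model's load_weights could consume the
# iterator that model_loader.py stashes on model_config. The helpers
# `_load_from_checkpoint` and `_assign_param` are illustrative only.
class MyFlaxModel:
    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def load_weights(self, rng):
        weights_iterator = getattr(self.vllm_config.model_config,
                                   "model_weights_iterator", None)
        if weights_iterator is None:
            # Default path: read weights from the local / Hugging Face checkpoint.
            self._load_from_checkpoint(rng)
            return
        # Streamed path: consume (name, tensor) pairs as the streamer yields them.
        for name, tensor in weights_iterator:
            self._assign_param(name, tensor)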
