Commit 0a7f8b2 (parent 6b8b654)

implement runai model streamer for MODEL_IMPL_TYPE=flax_nnx

File tree: 5 files changed, +408 −143 lines

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# runai_model_streamer_loader
steps:
  - label: "Correctness tests for runai_model_streamer_loader"
    key: "runai_model_streamer_loader_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
  - label: "Record correctness test result for runai_model_streamer_loader"
    key: "record_runai_model_streamer_loader_CorrectnessTest"
    depends_on: "runai_model_streamer_loader_CorrectnessTest"
    env:
      CI_TARGET: "runai_model_streamer_loader"
      CI_STAGE: "CorrectnessTest"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest

  - label: "Performance tests for runai_model_streamer_loader"
    key: "runai_model_streamer_loader_PerformanceTest"
    depends_on: "record_runai_model_streamer_loader_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_performance
  - label: "Record performance test result for runai_model_streamer_loader"
    key: "record_runai_model_streamer_loader_PerformanceTest"
    depends_on: "runai_model_streamer_loader_PerformanceTest"
    env:
      CI_TARGET: "runai_model_streamer_loader"
      CI_STAGE: "PerformanceTest"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_PerformanceTest

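To run the same correctness target outside of Buildkite, the pipeline's pytest invocation can be reproduced directly; a minimal sketch, assuming the repo-relative test path matches the /workspace layout used in CI:

# Sketch: run the pipeline's correctness target from Python instead of CI.
import pytest

pytest.main([
    "-s", "-v",
    "tests/e2e/test_runai_model_streamer_loader.py::test_correctness",
])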
requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ torchvision==0.23.0
 pathwaysutils
 parameterized
 numba==0.62.1
+runai-model-streamer[s3,gcs]==0.15.0
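The new dependency is what vLLM's runai_streamer load format drives under the hood. A minimal sketch of how the package streams safetensors from object storage, mirroring the iterator pattern vLLM's weight utilities wrap; treat the exact class and method names as assumptions, not part of this commit:

# Sketch: yield (name, tensor) pairs from s3:// or gs:// safetensors files
# without staging them on local disk. The [s3,gcs] extras supply the
# object-storage backends.
from runai_model_streamer import SafetensorsStreamer  # assumed API

def stream_tensors(safetensors_files):
    with SafetensorsStreamer() as streamer:
        for file_path in safetensors_files:
            streamer.stream_file(file_path)
            yield from streamer.get_tensors()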
tests/e2e/test_runai_model_streamer_loader.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
from __future__ import annotations

import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)


# TODO(amacaskill): Replace with GKE owned GCS bucket, and a smaller model.
@pytest.fixture
def gcs_model_name():
    return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"


@pytest.fixture
def hf_model_name():
    return "meta-llama/Meta-Llama-3-8B"


@pytest.fixture
def prompt():
    return "Hello, my name is"


def test_correctness(
    sampling_config: SamplingParams,
    gcs_model_name: str,
    hf_model_name: str,
    prompt: str,
    monkeypatch: pytest.MonkeyPatch,
):
    '''
    Compare the outputs of a model loaded from GCS via runai_model_streamer
    and a model loaded from Hugging Face. The outputs should be the same.
    '''
    # Test with the GCS model using runai_model_streamer.
    gcs_llm = LLM(model=gcs_model_name,
                  load_format="runai_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    gcs_outputs = gcs_llm.generate([prompt], sampling_config)
    gcs_output_text = gcs_outputs[0].outputs[0].text
    del gcs_llm
    time.sleep(10)  # Wait for TPUs to be released.

    # Test with the Hugging Face model.
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    hf_outputs = hf_llm.generate([prompt], sampling_config)
    hf_output_text = hf_outputs[0].outputs[0].text
    del hf_llm
    time.sleep(10)  # Wait for TPUs to be released.

    assert gcs_output_text == hf_output_text, (
        f"Outputs do not match! "
        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}"
    )


def test_performance(
    gcs_model_name: str,
    hf_model_name: str,
    monkeypatch: pytest.MonkeyPatch,
):
    '''
    Compare the model load time of a model loaded from GCS via
    runai_model_streamer and a model loaded from Hugging Face.
    '''
    # Time loading from GCS.
    start_time = time.time()
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
    monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
                       "https://storage.googleapis.com")
    gcs_llm = LLM(model=gcs_model_name,
                  load_format="runai_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    gcs_load_time = time.time() - start_time
    print(f"GCS model load time: {gcs_load_time:.2f} seconds")
    del gcs_llm
    time.sleep(10)

    # Time loading from Hugging Face.
    start_time = time.time()
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    hf_load_time = time.time() - start_time
    print(f"Hugging Face model load time: {hf_load_time:.2f} seconds")
    del hf_llm
    time.sleep(10)

    print(f"GCS load time: {gcs_load_time:.2f}s, HF load time: {hf_load_time:.2f}s")
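For quick manual verification outside pytest, the same GCS load path can be exercised standalone; a minimal sketch reusing the bucket, environment variables, and LLM arguments from the tests above (only the script framing is new):

# Sketch: load the public GCS checkpoint via the streamer and generate once.
import os

from vllm import LLM, SamplingParams

os.environ["GOOGLE_CLOUD_PROJECT"] = "fake-project"
os.environ["RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS"] = "true"

llm = LLM(model="gs://vertex-model-garden-public-us/llama3/llama3-8b-hf",
          load_format="runai_streamer",
          max_model_len=128)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=10))
print(outputs[0].outputs[0].text)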

tpu_inference/models/common/model_loader.py

Lines changed: 19 additions & 1 deletion
@@ -10,6 +10,9 @@
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import RunaiModelStreamerLoader
+
 
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
@@ -177,7 +180,22 @@ def create_sharded_model():
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # Set the weights iterator at runtime so that every model's
+            # load_weights signature does not have to change. This also
+            # avoids a TypeError when the RunaiModelStreamerLoader is used
+            # with a flax_nnx model whose load_weights does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
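On the model side, a flax_nnx load_weights can then pick the iterator off the model config when it is present. The sketch below is illustrative only: the helper names and fallback path are assumptions, not code from this commit.

# Hypothetical sketch of a flax_nnx model consuming the injected iterator.
def load_weights(self, rng):
    weights_iterator = getattr(self.vllm_config.model_config,
                               "model_weights_iterator", None)
    if weights_iterator is not None:
        # Streamed path: (name, tensor) pairs arriving from the
        # RunaiModelStreamerLoader, already read from GCS/S3.
        for name, tensor in weights_iterator:
            self._assign_weight(name, tensor)  # hypothetical helper
    else:
        # Existing path: load from a local/HF checkpoint as before.
        self._load_from_checkpoint(rng)  # hypothetical helper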
