
Commit a17138c

implement runai model streamer for MODEL_IMPL_TYPE=flax_nnx
Signed-off-by: Alexis MacAskill <amacaskill@google.com>
1 parent 6b8b654 commit a17138c

File tree

5 files changed: +382 -143 lines

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
# runai_model_streamer_loader
steps:
  - label: "Correctness tests for runai_model_streamer_loader"
    key: "runai_model_streamer_loader_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
  - label: "Record correctness test result for runai_model_streamer_loader"
    key: "record_runai_model_streamer_loader_CorrectnessTest"
    depends_on: "runai_model_streamer_loader_CorrectnessTest"
    env:
      CI_TARGET: "runai_model_streamer_loader"
      CI_STAGE: "CorrectnessTest"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest

  - label: "Performance tests for runai_model_streamer_loader"
    key: "runai_model_streamer_loader_PerformanceTest"
    depends_on: "record_runai_model_streamer_loader_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_performance
  - label: "Record performance test result for runai_model_streamer_loader"
    key: "record_runai_model_streamer_loader_PerformanceTest"
    depends_on: "runai_model_streamer_loader_PerformanceTest"
    env:
      CI_TARGET: "runai_model_streamer_loader"
      CI_STAGE: "PerformanceTest"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_PerformanceTest

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -16,3 +16,4 @@ torchvision==0.23.0
 pathwaysutils
 parameterized
 numba==0.62.1
+runai-model-streamer[s3,gcs]==0.15.0
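
The runai_streamer load format exercised by the tests below is backed by this package. For orientation, here is a minimal sketch of the package's standalone streaming API as documented upstream; the file path below is hypothetical, and object-store URIs are intended to work similarly once the [s3,gcs] extras are installed:

    from runai_model_streamer import SafetensorsStreamer

    # Hypothetical local checkpoint shard; s3:// or gs:// URIs are the
    # object-store equivalents when the extras are installed.
    file_path = "/tmp/llama3-8b/model-00001-of-00004.safetensors"

    with SafetensorsStreamer() as streamer:
        streamer.stream_file(file_path)
        # Tensors are yielded as (name, tensor) pairs as streaming completes.
        for name, tensor in streamer.get_tensors():
            print(name, tuple(tensor.shape))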
tests/e2e/test_runai_model_streamer_loader.py

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
from __future__ import annotations

import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0,
                          max_tokens=10,
                          ignore_eos=True)


# TODO(amacaskill): Replace with a GKE-owned GCS bucket and a smaller model.
@pytest.fixture
def gcs_model_name():
    return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"


@pytest.fixture
def hf_model_name():
    return "meta-llama/Meta-Llama-3-8B"


@pytest.fixture
def prompt():
    return "Hello, my name is"


def test_correctness(
    sampling_config: SamplingParams,
    gcs_model_name: str,
    hf_model_name: str,
    prompt: str,
    monkeypatch: pytest.MonkeyPatch,
):
    '''
    Compare the outputs of a model loaded from GCS via runai_model_streamer
    against the same model loaded from Hugging Face. The outputs should be
    identical. These tests use tensor_parallel_size=1: the model is 16GB and
    v6e has 32GB of HBM, so it should fit.
    '''
    gcs_llm = LLM(model=gcs_model_name,
                  load_format="runai_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    gcs_outputs = gcs_llm.generate([prompt], sampling_config)
    gcs_output_text = gcs_outputs[0].outputs[0].text
    del gcs_llm
    time.sleep(10)  # Wait for TPUs to be released.

    # Test with the Hugging Face model.
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    hf_outputs = hf_llm.generate([prompt], sampling_config)
    hf_output_text = hf_outputs[0].outputs[0].text
    del hf_llm
    time.sleep(10)  # Wait for TPUs to be released.

    assert gcs_output_text == hf_output_text, (
        f"Outputs do not match! "
        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")


def test_performance(
    gcs_model_name: str,
    hf_model_name: str,
    monkeypatch: pytest.MonkeyPatch,
):
    '''
    Compare the load time of a model streamed from GCS via
    runai_model_streamer against the same model loaded from Hugging Face.
    These tests use tensor_parallel_size=1: the model is 16GB and v6e has
    32GB of HBM, so it should fit. The test fails if the GCS load time is
    more than 100% higher than the Hugging Face load time. Load times are
    comparable for smaller models; the RunAI model streamer should be
    significantly faster for larger models, but those take too long to
    test here.
    '''
    # Time loading from GCS.
    start_time = time.time()
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
    monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
                       "https://storage.googleapis.com")
    gcs_llm = LLM(model=gcs_model_name,
                  load_format="runai_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    gcs_load_time = time.time() - start_time
    print(f"RunAI model load time: {gcs_load_time:.2f} seconds")
    del gcs_llm
    time.sleep(10)

    # Time loading from Hugging Face.
    start_time = time.time()
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    hf_load_time = time.time() - start_time
    print(f"Hugging Face model load time: {hf_load_time:.2f} seconds")
    del hf_llm
    time.sleep(10)

    print(f"RunAI model load time: {gcs_load_time:.2f}s, "
          f"HF model load time: {hf_load_time:.2f}s")
    assert gcs_load_time < hf_load_time * 2, (
        f"RunAI model load time ({gcs_load_time:.2f}s) is more than 100% "
        f"higher than Hugging Face model load time ({hf_load_time:.2f}s).")

tpu_inference/models/common/model_loader.py

Lines changed: 19 additions & 1 deletion
@@ -10,6 +10,9 @@
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import RunaiModelStreamerLoader

 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
@@ -177,7 +180,22 @@ def create_sharded_model():
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # Set the weights iterator at runtime to avoid changing every
+            # model's load_weights signature. This also avoids a runtime
+            # TypeError when the RunaiModelStreamerLoader is used with a
+            # flax_nnx model whose load_weights function does not accept
+            # the weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
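
Because the iterator is stashed on model_config rather than passed as an argument, a flax_nnx model can opt in without a signature change. A hedged sketch of what the consuming side might look like; the attribute lookup mirrors what model_loader.py sets above, but the method body and helper names are hypothetical, not part of this commit:

    # Hypothetical flax_nnx model method (illustrative only).
    def load_weights(self, rng):
        weights_iterator = getattr(self.vllm_config.model_config,
                                   "model_weights_iterator", None)
        if weights_iterator is not None:
            # Streamed path: consume (name, tensor) pairs produced by the
            # RunAI streamer loader as they arrive.
            for name, tensor in weights_iterator:
                self._assign_weight(name, tensor)  # hypothetical helper
        else:
            # Default path: read checkpoint files directly, as before.
            self._load_from_checkpoint(rng)  # hypothetical helper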
