40 changes: 40 additions & 0 deletions .buildkite/features/runai_model_streamer_loader.yml
@@ -0,0 +1,40 @@
# runai_model_streamer_loader
steps:
- label: "Correctness tests for runai_model_streamer_loader"
key: "runai_model_streamer_loader_CorrectnessTest"
soft_fail: true
agents:
queue: tpu_v6e_queue
commands:
- .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
- label: "Record correctness test result for runai_model_streamer_loader"
key: "record_runai_model_streamer_loader_CorrectnessTest"
depends_on: "runai_model_streamer_loader_CorrectnessTest"
env:
CI_TARGET: "runai_model_streamer_loader"
CI_STAGE: "CorrectnessTest"
agents:
queue: cpu
commands:
- |
.buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest

- label: "Performance tests for runai_model_streamer_loader"
key: "runai_model_streamer_loader_PerformanceTest"
depends_on: "record_runai_model_streamer_loader_CorrectnessTest"
soft_fail: true
agents:
queue: tpu_v6e_queue
commands:
- .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_performance
- label: "Record performance test result for runai_model_streamer_loader"
key: "record_runai_model_streamer_loader_PerformanceTest"
depends_on: "runai_model_streamer_loader_PerformanceTest"
env:
CI_TARGET: "runai_model_streamer_loader"
CI_STAGE: "PerformanceTest"
agents:
queue: cpu
commands:
- |
.buildkite/scripts/record_step_result.sh runai_model_streamer_loader_PerformanceTest
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ torchvision==0.23.0
pathwaysutils
parameterized
numba==0.62.1
runai-model-streamer[s3,gcs]==0.15.0
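
For context, the pinned package provides the streaming safetensors reader that vLLM's `runai_streamer` load format builds on. A rough sketch of direct use follows; the exact signatures are an assumption based on the library's documented usage, not a reference:

```python
# Sketch of the streaming API the new dependency provides, as consumed by
# vLLM's runai_streamer load format. Signatures are an assumption.
from runai_model_streamer import SafetensorsStreamer

def stream_tensors(safetensors_file: str):
    """Yield (name, tensor) pairs as they stream in from s3://, gs://, or disk."""
    with SafetensorsStreamer() as streamer:
        streamer.stream_file(safetensors_file)
        yield from streamer.get_tensors()
```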
75 changes: 75 additions & 0 deletions tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,75 @@
from __future__ import annotations

import time

import pytest
from vllm import LLM, SamplingParams


@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0,
max_tokens=10,
ignore_eos=True)


# TODO(amacaskill): Replace with a GKE-owned GCS bucket.
@pytest.fixture
def gcs_model_name():
return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"


@pytest.fixture
def hf_model_name():
return "meta-llama/Meta-Llama-3-8B"


@pytest.fixture
def prompt():
return "Hello, my name is"


def test_correctness(
sampling_config: SamplingParams,
gcs_model_name: str,
hf_model_name: str,
prompt: str,
monkeypatch: pytest.MonkeyPatch,
):
"""
Compare the outputs of a model loaded from GCS via runai_model_streamer
with those of the same model loaded from Hugging Face; the outputs should
be identical. The test uses tensor_parallel_size=1: the model is ~16GB and
a v6e chip has 32GB of HBM, so it fits on a single chip.
"""
# Set ENV variables so that runai_model_streamer uses anonymous GCS access.
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
monkeypatch.setenv(
"CLOUD_STORAGE_EMULATOR_ENDPOINT", "https://storage.googleapis.com"
)
gcs_llm = LLM(model=gcs_model_name,
load_format="runai_streamer",
max_model_len=128,
max_num_seqs=16,
max_num_batched_tokens=256)
gcs_outputs = gcs_llm.generate([prompt], sampling_config)
gcs_output_text = gcs_outputs[0].outputs[0].text
del gcs_llm
time.sleep(10) # Wait for TPUs to be released

# Test with Hugging Face model
hf_llm = LLM(model=hf_model_name,
max_model_len=128,
max_num_seqs=16,
max_num_batched_tokens=256)
hf_outputs = hf_llm.generate([prompt], sampling_config)
hf_output_text = hf_outputs[0].outputs[0].text
del hf_llm
time.sleep(10) # Wait for TPUs to be released

assert gcs_output_text == hf_output_text, (
f"Outputs do not match! "
f"GCS output: {gcs_output_text}, HF output: {hf_output_text}"
)
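
For anyone who wants to exercise the streaming path outside pytest, a minimal standalone sketch of the same setup (same public bucket, anonymous-access environment variables, and engine arguments as `test_correctness`):

```python
# Minimal standalone repro of the GCS streaming path that test_correctness
# exercises. The env vars are set before the engine is constructed so the
# streamer picks up anonymous GCS credentials.
import os

os.environ["GOOGLE_CLOUD_PROJECT"] = "fake-project"
os.environ["RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS"] = "true"
os.environ["CLOUD_STORAGE_EMULATOR_ENDPOINT"] = "https://storage.googleapis.com"

from vllm import LLM, SamplingParams

llm = LLM(model="gs://vertex-model-garden-public-us/llama3/llama3-8b-hf",
          load_format="runai_streamer",
          max_model_len=128,
          max_num_seqs=16,
          max_num_batched_tokens=256)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=10,
                                      ignore_eos=True))
print(outputs[0].outputs[0].text)
```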

20 changes: 19 additions & 1 deletion tpu_inference/models/common/model_loader.py
@@ -10,6 +10,9 @@
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.utils.func_utils import supports_kw
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.runai_streamer_loader import (
    RunaiModelStreamerLoader)


from tpu_inference.layers.jax.sharding import ShardingAxisName
from tpu_inference.logger import init_logger
@@ -177,7 +180,22 @@ def create_sharded_model():
# the model creation again, otherwise the model forward will have
# non-trivial overhead in PjitFunction.
with mesh:
model.load_weights(rng)
loader = get_model_loader(vllm_config.load_config)
if isinstance(loader, RunaiModelStreamerLoader):
model_weights = vllm_config.model_config.model
if hasattr(vllm_config.model_config, "model_weights"):
model_weights = vllm_config.model_config.model_weights
weights_iterator = loader._get_weights_iterator(
    model_weights, vllm_config.model_config.revision)
# Attach the weights iterator to the config at runtime so that we do not
# have to change every model's load_weights signature. This also avoids a
# TypeError when RunaiModelStreamerLoader is used with a flax_nnx model
# whose load_weights does not accept a weights_iterator keyword argument.
vllm_config.model_config.model_weights_iterator = weights_iterator
model.load_weights(rng)
del vllm_config.model_config.model_weights_iterator
else:
model.load_weights(rng)
jit_model = create_jit_model(
model,
use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
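
To make the runtime-attribute contract concrete, here is a hypothetical sketch of how a flax_nnx model's `load_weights` could consume the iterator. Only the `model_weights_iterator` attribute name comes from the diff above; the surrounding method and the `_assign_param` helper are illustrative placeholders, not code from this repo:

```python
# Hypothetical sketch (not code from this PR) of a model-side consumer of
# the iterator that model_loader attaches at runtime.
def load_weights(self, rng):
    weights_iterator = getattr(self.vllm_config.model_config,
                               "model_weights_iterator", None)
    if weights_iterator is not None:
        # Streamed path: consume (name, tensor) pairs produced by
        # RunaiModelStreamerLoader._get_weights_iterator.
        for name, tensor in weights_iterator:
            self._assign_param(name, tensor)  # hypothetical helper
    else:
        self._load_weights_from_disk(rng)  # hypothetical default path
```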