diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh
index 6e1f79ae649e..a93b79c0b1b2 100755
--- a/.buildkite/run-tpu-v1-test.sh
+++ b/.buildkite/run-tpu-v1-test.sh
@@ -32,7 +32,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_5 \
     && python3 /workspace/vllm/examples/offline_inference/tpu.py \
     && echo TEST_6 \
-    && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \
+    && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
+    && echo TEST_7 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \


 # TODO: This test fails because it uses RANDOM_SEED sampling
diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index 4e5a57bee327..f535abedea22 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -1,7 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import tempfile
-from time import time
-
 import pytest

 from vllm import LLM, envs
@@ -15,60 +12,6 @@
 )


-@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
-@pytest.mark.skipif(not current_platform.is_tpu(),
-                    reason="This test needs a TPU")
-def test_sampler_compilation(model_name: str, monkeypatch):
-    """
-    Check that no recompilation happens despite changing sampling parameters.
-    We can't read XLA metrics from the engine process, hence we measure time.
-    """
-    with tempfile.TemporaryDirectory() as temp_dir:
-        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
-        # Compiling model init may still take some time, enforce_eager to skip.
-        llm = LLM(model_name,
-                  enforce_eager=True,
-                  max_num_seqs=16,
-                  max_model_len=1024,
-                  gpu_memory_utilization=0.5)
-        prompts = [
-            "A robot may not injure a human being",
-            "It is only with the heart that one can see rightly;",
-        ]
-        # First inference should be slow
-        sampling_params = SamplingParams(
-            temperature=0.7,
-            # top_p=0.6, # TODO too slow!
-            top_k=10,
-            min_p=0.2,
-            max_tokens=16)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run1 = time() - s
-
-        # Second request with different params, but for which we
-        # compiled for in previous eager iteration.
-        sampling_params = SamplingParams(temperature=0.1,
-                                         top_k=12,
-                                         min_p=0.8,
-                                         max_tokens=24)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run2 = time() - s
-        # Much faster after compiling
-        assert run1 * 0.1 > run2
-        print("TIMES", run1, run2)
-
-        # Third request with min_p set to "None". It will not trigger
-        # recompilation as a default 0 value will be used.
-        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run3 = time() - s
-        assert run1 * 0.1 > run3
-        print("TIMES", run1, run3)
-
-
 @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
 @pytest.mark.skipif(not current_platform.is_tpu(),
                     reason="This test needs a TPU")
@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
     """
     Test significantly different sampling params to assert the model
     produces different results.
     """
-    llm = LLM(
-        model_name,
-        enforce_eager=True,
-        max_num_seqs=1,
-        max_model_len=64,
-        # TODO: setting to 0.5 or it will go OOM
-        gpu_memory_utilization=0.5)
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=512,
+              max_num_batched_tokens=512)
     prompts = [
         "Write a short story about a robot that dreams for the first time."
     ]
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index d605c4b65e9d..89d3ddf51d74 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -88,6 +88,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
             # Copy slice from CPU to corresponding TPU pre-allocated tensor.
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
             tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]

         # NOTE NickLucche The sync CPU-TPU graph we produce here must be
@@ -101,13 +102,6 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
         copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
                    DEFAULT_SAMPLING_PARAMS["min_p"])

-        # copy_slice(input_batch.frequency_penalties_cpu_tensor,
-        #            input_batch.frequency_penalties)
-        # copy_slice(input_batch.presence_penalties_cpu_tensor,
-        #            input_batch.presence_penalties)
-        # copy_slice(input_batch.repetition_penalties_cpu_tensor,
-        #            input_batch.repetition_penalties)
-
         xm.mark_step()
         xm.wait_device_ops()

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 65a4048ae74d..3df555d602cb 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -88,6 +88,8 @@ def __init__(
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        # InputBatch needs to work with sampling tensors greater than padding
+        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

         # Model-related.
@@ -787,6 +789,7 @@ def capture_model(self) -> None:
         dummy_hidden = torch.randn((num_tokens, hsize),
                                    device=device,
                                    dtype=torch.bfloat16)
+        # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
         while True:
             indices = torch.zeros(
                 num_reqs_to_sample,
@@ -803,7 +806,9 @@
             out = out.cpu()
             if num_reqs_to_sample >= self.max_num_reqs:
                 break
-            num_reqs_to_sample *= 2
+            # Make sure to compile the `max_num_reqs` upper-limit case
+            num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
+                num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -896,7 +901,6 @@

         return hidden_states

-    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,
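For illustration, the sketch below reproduces the padding schedule the new `capture_model` loop walks through. It is a minimal sketch, not the real helper: the body of `_get_padded_num_reqs_with_upper_limit` is not shown in this diff, so the round-up-to-the-next-power-of-two behavior, the `MIN_NUM_SEQS = 8` floor, and the starting size are assumptions inferred from the `[8, 16, .., 128,.., max_num_reqs]` comment and the `max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)` line above.

```python
# Hypothetical sketch of the request-count padding assumed above.
# The real helper lives in vllm/v1/worker/tpu_model_runner.py and may differ.
MIN_NUM_SEQS = 8  # assumed floor, matching the "[8, 16, .., 128, ..]" comment


def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    # Round up to the next power of two, never below MIN_NUM_SEQS and never
    # above the configured upper limit (max_num_reqs).
    padded = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(padded, upper_limit)


if __name__ == "__main__":
    max_num_reqs = 96  # e.g. a non-power-of-two --max-num-seqs
    sizes, num_reqs_to_sample = [], MIN_NUM_SEQS
    while True:
        sizes.append(num_reqs_to_sample)
        if num_reqs_to_sample >= max_num_reqs:
            break
        # Same update as the new capture_model loop: jump to the next padded
        # size so the max_num_reqs upper-limit case is always compiled too.
        num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
            num_reqs_to_sample + 1, max_num_reqs)
    print(sizes)  # [8, 16, 32, 64, 96]
```

Compiling only this fixed set of sizes keeps the sampling graphs static on TPU: any batch between 1 and `max_num_reqs` requests is padded up to one of the precompiled shapes instead of triggering a recompilation, which is also why `copy_slice` in `metadata.py` pads the CPU tensors out to `padded_num_reqs` before the device copy.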