From 5c090542dda52ddd0eae6f5bd98ab35855836614 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 7 Apr 2025 08:41:14 +0000
Subject: [PATCH 1/6] disable per-request generators

Signed-off-by: NickLucche
---
 tests/v1/tpu/test_sampler.py       |  5 ++++-
 vllm/v1/sample/tpu/metadata.py     | 16 ++++++++++------
 vllm/v1/worker/tpu_model_runner.py |  8 +++-----
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index f535abedea22..a6b45043910f 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -28,7 +28,10 @@ def test_sampler_different(model_name: str):
     prompts = [
         "Write a short story about a robot that dreams for the first time."
     ]
-    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
+    sampling_params = SamplingParams(temperature=0.9,
+                                     min_p=0.2,
+                                     max_tokens=64,
+                                     seed=42)
     output = llm.generate(prompts, sampling_params)
 
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index 10995d6787a5..3950fda3e5ea 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
     # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True
 
-    # Generator not supported by xla
-    generators: dict[int,
-                     torch.Generator] = field(default_factory=lambda: dict())
-
     # unsupported, you need to return an extra tensor of static size BxV
     max_num_logprobs = None
 
@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None
 
+    # Generator not supported by xla
+    _generators: dict[int,
+                      torch.Generator] = field(default_factory=lambda: dict())
+
+    @property
+    def generators(self) -> dict[int, torch.Generator]:
+        # Generator not supported by torch/xla. This field must be immutable.
+        return self._generators
+
     @classmethod
     def from_input_batch(
         cls,
@@ -109,5 +114,4 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device),
-            generators=input_batch.generators)
+                xla_device))
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 773c426474fc..8111d4dd5b70 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -268,10 +268,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
             if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
+                logger.warning("Torch XLA does not support per-request seed."
+                               f" Seed {sampling_params.seed} will be ignored")
 
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
@@ -280,7 +278,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                generator=generator,
+                generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
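
Note on the metadata change above: per-request torch.Generator state cannot be
carried into XLA-compiled graphs, so the mapping is kept private, never
populated, and exposed through a read-only property (the diff's own comment:
"This field must be immutable"). A minimal, self-contained sketch of that
pattern, with simplified names rather than the actual vLLM class:

    from dataclasses import dataclass, field

    import torch


    @dataclass
    class SamplingMetadataSketch:
        """Trimmed-down stand-in for TPUSupportedSamplingMetadata."""
        temperature: torch.Tensor
        all_greedy: bool = True
        # Never populated on TPU: XLA cannot replay per-request generator state.
        _generators: dict[int, torch.Generator] = field(default_factory=dict)

        @property
        def generators(self) -> dict[int, torch.Generator]:
            # No setter: the mapping is not meant to be replaced or filled in.
            return self._generators


    meta = SamplingMetadataSketch(temperature=torch.ones(4))
    # Sampler code that looks up a request's generator simply gets None and
    # falls back to the global RNG, so no TPU-specific branching is needed.
    assert meta.generators.get(0) is None
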
From 88363fc87276b53e08a4e0252fd2b30a0095936b Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 7 Apr 2025 14:59:52 +0000
Subject: [PATCH 2/6] re-add test

Signed-off-by: NickLucche
---
 .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index 87f74277cf90..eb58bb53c66d 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -40,8 +40,6 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_8 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
     && echo TEST_9 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
-
-
-# TODO: This test fails because it uses RANDOM_SEED sampling
-# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
+    && echo TEST_10 \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_custom_dispatcher.py" \
\ No newline at end of file

From 42aaca7300d1b198e9f52436f0fefe61c55c46c4 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Tue, 8 Apr 2025 07:20:57 +0000
Subject: [PATCH 3/6] raise error to client when per-request seed is set

Signed-off-by: NickLucche
---
 .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh |  2 +-
 tests/v1/tpu/test_sampler.py                      | 10 ++++++----
 vllm/v1/engine/processor.py                       |  6 +++++-
 vllm/v1/worker/tpu_model_runner.py                |  4 ----
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index eb58bb53c66d..ee6202979032 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -42,4 +42,4 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_9 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
     && echo TEST_10 \
-    && pytest -s -v /workspace/vllm/tests/tpu/test_custom_dispatcher.py" \
\ No newline at end of file
+    && pytest -s -v /workspace/vllm/tests/tpu/test_custom_dispatcher.py" \
diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index a6b45043910f..0147da533517 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -28,12 +28,14 @@ def test_sampler_different(model_name: str):
     prompts = [
         "Write a short story about a robot that dreams for the first time."
     ]
-    sampling_params = SamplingParams(temperature=0.9,
-                                     min_p=0.2,
-                                     max_tokens=64,
-                                     seed=42)
+    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
     output = llm.generate(prompts, sampling_params)
 
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 7d1913ecebed..7ab3b5ee1c32 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -14,9 +14,10 @@
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
@@ -77,6 +78,9 @@ def _validate_sampling_params(
         params: SamplingParams,
     ) -> None:
         self._validate_structured_output(params)
+        if (current_platform.is_tpu()
+                and params.sampling_type == SamplingType.RANDOM_SEED):
+            raise ValueError("Torch XLA does not support per-request seed.")
         if params.allowed_token_ids is None:
             return
 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 8111d4dd5b70..111bb5e911e1 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -23,7 +23,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -267,9 +266,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                logger.warning("Torch XLA does not support per-request seed."
-                               f" Seed {sampling_params.seed} will be ignored")
 
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
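
With the processor-level check above, a per-request seed on TPU now fails fast
with a ValueError instead of being silently ignored. Roughly what a caller sees
on a TPU host (illustrative sketch, not part of the patches; the model name is
a placeholder and the prompt mirrors tests/v1/tpu/test_sampler.py):

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
    prompts = [
        "Write a short story about a robot that dreams for the first time."
    ]
    try:
        # seed=42 selects RANDOM_SEED sampling, which the TPU path rejects.
        llm.generate(prompts, SamplingParams(temperature=0.3, seed=42))
    except ValueError as err:
        print(err)  # Torch XLA does not support per-request seed.
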
From 373333955c1af84a99f67354a489ee48c380ff46 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Thu, 10 Apr 2025 07:45:07 +0000
Subject: [PATCH 4/6] move seed check to platform specific validate

Signed-off-by: NickLucche
---
 vllm/platforms/tpu.py       | 4 +++-
 vllm/v1/engine/processor.py | 6 +-----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 61e84a6d6f95..62bf900f76ef 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -7,7 +7,7 @@
 import vllm.envs as envs
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 
 from .interface import Platform, PlatformEnum, _Backend
 
@@ -149,3 +149,5 @@ def validate_request(
                       SamplingParams) and params.guided_decoding is not None:
             raise ValueError("Structured output is not supported on "
                              f"{cls.device_name}.")
+        if params.sampling_type == SamplingType.RANDOM_SEED:
+            raise ValueError("Torch XLA does not support per-request seed.")
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 7ab3b5ee1c32..7d1913ecebed 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -14,10 +14,9 @@
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
-from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
@@ -78,9 +77,6 @@ def _validate_sampling_params(
         params: SamplingParams,
     ) -> None:
         self._validate_structured_output(params)
-        if (current_platform.is_tpu()
-                and params.sampling_type == SamplingType.RANDOM_SEED):
-            raise ValueError("Torch XLA does not support per-request seed.")
         if params.allowed_token_ids is None:
             return
 
From 98575862d675bd377125d8d9a74e0371b13a6e9f Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Thu, 10 Apr 2025 16:38:44 +0000
Subject: [PATCH 5/6] revert add test

Signed-off-by: NickLucche
---
 .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index ee6202979032..87f74277cf90 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -40,6 +40,8 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_8 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
     && echo TEST_9 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
-    && echo TEST_10 \
-    && pytest -s -v /workspace/vllm/tests/tpu/test_custom_dispatcher.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
+
+
+# TODO: This test fails because it uses RANDOM_SEED sampling
+# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
From 9885fe4c1905f44465fda3872a54f024c7958827 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Thu, 10 Apr 2025 17:08:24 +0000
Subject: [PATCH 6/6] params type check

Signed-off-by: NickLucche
---
 vllm/platforms/tpu.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 62bf900f76ef..ada599c27b44 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -145,9 +145,10 @@ def validate_request(
         params: Union[SamplingParams, PoolingParams],
     ) -> None:
         """Raises if this request is unsupported on this platform"""
-        if isinstance(params,
-                      SamplingParams) and params.guided_decoding is not None:
-            raise ValueError("Structured output is not supported on "
-                             f"{cls.device_name}.")
-        if params.sampling_type == SamplingType.RANDOM_SEED:
-            raise ValueError("Torch XLA does not support per-request seed.")
+        if isinstance(params, SamplingParams):
+            if params.guided_decoding is not None:
+                raise ValueError("Structured output is not supported on "
+                                 f"{cls.device_name}.")
+            if params.sampling_type == SamplingType.RANDOM_SEED:
+                raise ValueError(
+                    "Torch XLA does not support per-request seed.")
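
The final patch guards the seed check behind an isinstance test because
validate_request receives either SamplingParams or PoolingParams, and pooling
requests do not carry sampling fields such as sampling_type. A reduced sketch
of why the guard matters, using hypothetical stand-in classes rather than the
vLLM ones:

    from dataclasses import dataclass
    from typing import Optional, Union


    @dataclass
    class SamplingParamsSketch:
        seed: Optional[int] = None

        @property
        def sampling_type(self) -> str:
            # The real SamplingParams exposes an enum; a string keeps this short.
            return "RANDOM_SEED" if self.seed is not None else "RANDOM"


    @dataclass
    class PoolingParamsSketch:
        # Embedding/pooling requests have no seed or sampling_type at all.
        dimensions: Optional[int] = None


    def validate_request(
            params: Union[SamplingParamsSketch, PoolingParamsSketch]) -> None:
        # Without the isinstance guard, touching params.sampling_type would
        # raise AttributeError for pooling requests instead of validating them.
        if isinstance(params, SamplingParamsSketch):
            if params.sampling_type == "RANDOM_SEED":
                raise ValueError(
                    "Torch XLA does not support per-request seed.")


    validate_request(PoolingParamsSketch())    # accepted: nothing to check
    validate_request(SamplingParamsSketch())   # accepted: unseeded sampling
    # validate_request(SamplingParamsSketch(seed=42))  # raises ValueError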